#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#include <torch/types.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Entry point into RNNT Loss
10
std::tuple<torch::Tensor, std::optional<torch::Tensor>> compute(
Caroline Chen's avatar
Caroline Chen committed
11
12
    torch::Tensor& logits,
    const torch::Tensor& targets,
13
14
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
Caroline Chen's avatar
Caroline Chen committed
15
    int64_t blank,
16
17
    double clamp,
    bool fused_log_softmax = true) {
18
19
20
21
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
22
      logits.device().type() == logit_lengths.device().type(),
23
24
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
25
      logits.device().type() == target_lengths.device().type(),
26
27
28
29
30
31
32
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
33
34
      logit_lengths.dtype() == torch::kInt32,
      "logit_lengths must be int32 type");
35
  TORCH_CHECK(
36
      target_lengths.dtype() == torch::kInt32,
37
38
39
40
      "target_lengths must be int32 type");

  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
41
42
43
44
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");
45
46
47
48
49

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
50
51
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
52
53

  TORCH_CHECK(
54
      logit_lengths.size(0) == logits.size(0),
55
56
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
57
      target_lengths.size(0) == logits.size(0),
58
59
60
61
62
63
64
65
66
67
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  TORCH_CHECK(
68
      logits.size(1) == at::max(logit_lengths).item().toInt(),
69
70
      "input length mismatch");
  TORCH_CHECK(
71
      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
72
73
      "output length mismatch");
  TORCH_CHECK(
74
      targets.size(1) == at::max(target_lengths).item().toInt(),
75
76
      "target length mismatch");

Caroline Chen's avatar
Caroline Chen committed
77
  Options options;
78
79
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
Caroline Chen's avatar
Caroline Chen committed
80
81
82
83
84
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;
85
  options.fusedLogSmax_ = fused_log_softmax;
Caroline Chen's avatar
Caroline Chen committed
86

87
  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
Caroline Chen's avatar
Caroline Chen committed
88
89
90
91
92
93
94
  options.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
95
  std::optional<torch::Tensor> gradients = torch::zeros_like(logits);
Caroline Chen's avatar
Caroline Chen committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
122
123
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
Caroline Chen's avatar
Caroline Chen committed
124
          /*costs=*/costs.data_ptr<float>(),
125
          /*gradients=*/gradients->data_ptr<float>());
Caroline Chen's avatar
Caroline Chen committed
126
127
128
129
130
131
132
      break;
    }
    case torch::ScalarType::Half: {
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
133
134
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
Caroline Chen's avatar
Caroline Chen committed
135
          /*costs=*/costs.data_ptr<c10::Half>(),
136
          /*gradients=*/gradients->data_ptr<c10::Half>());
Caroline Chen's avatar
Caroline Chen committed
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
      break;
    }
    default: {
      break;
    }
  };

  return std::make_tuple(costs, gradients);
}

// Registers `compute` as the CUDA implementation of the `rnnt_loss` operator
// in the `torchaudio` library.
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio