Unverified commit 2376e9c9 authored by Caroline Chen, committed by GitHub

Rename RNNT loss C++ parameters (#1602)

parent 6a8ecd98
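
For reference, a minimal sketch (not part of this commit) of calling the registered `torchaudio::rnnt_loss` op with the renamed keyword arguments. The shapes and dtypes follow the checks enforced in the CPU/GPU compute entry points further down; the concrete sizes are made up for illustration, and the sketch assumes a torchaudio build that registers this op.

```python
# Illustrative sketch only: calling torchaudio::rnnt_loss with the renamed
# keyword arguments. Assumes a torchaudio build that registers this op.
import torch
import torchaudio  # noqa: F401  -- loading torchaudio registers the op

batch, max_src, max_tgt, num_classes = 2, 10, 5, 20

# Shapes/dtypes mirror the TORCH_CHECKs in the compute kernels below:
# logits is 4-D (batch, time, target, class), targets is 2-D int32,
# both length tensors are 1-D int32.
logits = torch.rand(batch, max_src, max_tgt + 1, num_classes, dtype=torch.float32)
targets = torch.randint(1, num_classes, (batch, max_tgt), dtype=torch.int32)
logit_lengths = torch.full((batch,), max_src, dtype=torch.int32)   # was src_lengths
target_lengths = torch.full((batch,), max_tgt, dtype=torch.int32)  # was tgt_lengths

costs, gradients = torch.ops.torchaudio.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,    # renamed from src_lengths
    target_lengths=target_lengths,  # renamed from tgt_lengths
    blank=0,                        # must lie in [0, num_classes)
    clamp=-1.0,                     # illustrative value; not constrained by this diff
    fused_log_softmax=True,         # renamed from fused_log_smax
    reuse_logits_for_grads=True,
)
print(costs.shape)  # per-sequence costs; the second return value is optional
```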
......@@ -10,21 +10,21 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
torch::autograd::AutogradContext* ctx,
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_smax = true,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
torch::Tensor undef;
auto result = rnnt_loss(
logits,
targets,
src_lengths,
tgt_lengths,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_smax,
fused_log_softmax,
reuse_logits_for_grads);
auto costs = std::get<0>(result);
auto grads = std::get<1>(result).value_or(undef);
......@@ -47,21 +47,21 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_smax = true,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
at::AutoDispatchBelowADInplaceOrView guard;
auto results = RNNTLossFunction::apply(
logits,
targets,
src_lengths,
tgt_lengths,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_smax,
fused_log_softmax,
reuse_logits_for_grads);
return std::make_tuple(results[0], results[1]);
}
......
......@@ -4,11 +4,11 @@
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_smax = true,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("torchaudio::rnnt_loss", "")
......@@ -16,11 +16,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
return op.call(
logits,
targets,
src_lengths,
tgt_lengths,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_smax,
fused_log_softmax,
reuse_logits_for_grads);
}
......@@ -28,10 +28,10 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss(Tensor logits,"
"Tensor targets,"
"Tensor src_lengths,"
"Tensor tgt_lengths,"
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp,"
"bool fused_log_smax=True,"
"bool fused_log_softmax=True,"
"bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
}
......@@ -5,9 +5,9 @@
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_smax,
bool fused_log_softmax,
bool reuse_logits_for_grads);
......@@ -4,8 +4,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss_alphas(Tensor logits,"
"Tensor targets,"
"Tensor src_lengths,"
"Tensor tgt_lengths,"
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp) -> Tensor");
}
......@@ -4,8 +4,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss_betas(Tensor logits,"
"Tensor targets,"
"Tensor src_lengths,"
"Tensor tgt_lengths,"
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp) -> Tensor");
}
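
The alpha/beta helper ops take the same renamed length arguments. An illustrative sketch, again assuming a torchaudio build that registers these ops:

```python
# Illustrative sketch only: the forward/backward lattice helpers use the same
# renamed arguments (logit_lengths, target_lengths).
import torch
import torchaudio  # noqa: F401

batch, max_src, max_tgt, num_classes = 1, 4, 2, 6
logits = torch.rand(batch, max_src, max_tgt + 1, num_classes)
targets = torch.randint(1, num_classes, (batch, max_tgt), dtype=torch.int32)
logit_lengths = torch.full((batch,), max_src, dtype=torch.int32)
target_lengths = torch.full((batch,), max_tgt, dtype=torch.int32)

alphas = torch.ops.torchaudio.rnnt_loss_alphas(
    logits, targets, logit_lengths, target_lengths, blank=0, clamp=-1.0)
betas = torch.ops.torchaudio.rnnt_loss_betas(
    logits, targets, logit_lengths, target_lengths, blank=0, clamp=-1.0)
```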
......@@ -9,20 +9,20 @@ namespace cpu {
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_smax = true,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
TORCH_CHECK(
logits.device().type() == src_lengths.device().type(),
logits.device().type() == logit_lengths.device().type(),
"logits and logit_lengths must be on the same device");
TORCH_CHECK(
logits.device().type() == tgt_lengths.device().type(),
logits.device().type() == target_lengths.device().type(),
"logits and target_lengths must be on the same device");
TORCH_CHECK(
......@@ -30,28 +30,31 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
"logits must be float32 or float16 (half) type");
TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
TORCH_CHECK(
src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
logit_lengths.dtype() == torch::kInt32,
"logit_lengths must be int32 type");
TORCH_CHECK(
tgt_lengths.dtype() == torch::kInt32,
target_lengths.dtype() == torch::kInt32,
"target_lengths must be int32 type");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(
target_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(
src_lengths.size(0) == logits.size(0),
logit_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and logit_lengths");
TORCH_CHECK(
tgt_lengths.size(0) == logits.size(0),
target_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and target_lengths");
TORCH_CHECK(
targets.size(0) == logits.size(0),
......@@ -62,24 +65,24 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
"blank must be within [0, logits.shape[-1])");
TORCH_CHECK(
logits.size(1) == at::max(src_lengths).item().toInt(),
logits.size(1) == at::max(logit_lengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
logits.size(2) == at::max(target_lengths).item().toInt() + 1,
"output length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(tgt_lengths).item().toInt(),
targets.size(1) == at::max(target_lengths).item().toInt(),
"target length mismatch");
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_smax;
options.fusedLogSmax_ = fused_log_softmax;
CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
......@@ -121,8 +124,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
......@@ -133,8 +136,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<c10::Half>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr
......
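
The CPU entry point guards the renamed arguments with a fixed set of device, dtype, contiguity, and shape checks. A hedged Python restatement of those checks, written as a hypothetical helper that is not part of torchaudio, makes the caller-side contract explicit:

```python
# Hypothetical helper (not part of torchaudio): mirrors the TORCH_CHECKs in the
# CPU compute() entry point above, using the renamed argument names.
import torch

def check_rnnt_inputs(logits, targets, logit_lengths, target_lengths, blank):
    assert logits.device.type == targets.device.type
    assert logits.device.type == logit_lengths.device.type
    assert logits.device.type == target_lengths.device.type
    assert logits.dtype in (torch.float32, torch.float16)
    assert targets.dtype == torch.int32
    assert logit_lengths.dtype == torch.int32
    assert target_lengths.dtype == torch.int32
    assert logits.is_contiguous() and targets.is_contiguous()
    assert logit_lengths.is_contiguous() and target_lengths.is_contiguous()
    assert logits.dim() == 4      # (batch, time, target, class)
    assert targets.dim() == 2     # (batch, max target length)
    assert logit_lengths.dim() == 1 and target_lengths.dim() == 1
    assert logit_lengths.size(0) == logits.size(0)
    assert target_lengths.size(0) == logits.size(0)
    assert targets.size(0) == logits.size(0)
    assert 0 <= blank < logits.size(-1)
    assert logits.size(1) == int(logit_lengths.max())
    assert logits.size(2) == int(target_lengths.max()) + 1
    assert targets.size(1) == int(target_lengths.max())
```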
......@@ -8,13 +8,13 @@ namespace cpu {
torch::Tensor compute_alphas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
......@@ -55,8 +55,8 @@ torch::Tensor compute_alphas(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*alphas=*/alphas.data_ptr<float>());
return alphas;
}
......
......@@ -8,13 +8,13 @@ namespace cpu {
torch::Tensor compute_betas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
......@@ -25,7 +25,7 @@ torch::Tensor compute_betas(
options.device_ = CPU;
torch::Tensor costs = torch::empty(
tgt_lengths.size(0),
target_lengths.size(0),
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor betas = torch::zeros(
......@@ -59,8 +59,8 @@ torch::Tensor compute_betas(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*betas=*/betas.data_ptr<float>());
return betas;
......
......@@ -10,20 +10,20 @@ namespace gpu {
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_smax = true,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
TORCH_CHECK(
logits.device().type() == src_lengths.device().type(),
logits.device().type() == logit_lengths.device().type(),
"logits and logit_lengths must be on the same device");
TORCH_CHECK(
logits.device().type() == tgt_lengths.device().type(),
logits.device().type() == target_lengths.device().type(),
"logits and target_lengths must be on the same device");
TORCH_CHECK(
......@@ -31,28 +31,31 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
"logits must be float32 or float16 (half) type");
TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
TORCH_CHECK(
src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
logit_lengths.dtype() == torch::kInt32,
"logit_lengths must be int32 type");
TORCH_CHECK(
tgt_lengths.dtype() == torch::kInt32,
target_lengths.dtype() == torch::kInt32,
"target_lengths must be int32 type");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(
target_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(
src_lengths.size(0) == logits.size(0),
logit_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and logit_lengths");
TORCH_CHECK(
tgt_lengths.size(0) == logits.size(0),
target_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and target_lengths");
TORCH_CHECK(
targets.size(0) == logits.size(0),
......@@ -63,24 +66,24 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
"blank must be within [0, logits.shape[-1])");
TORCH_CHECK(
logits.size(1) == at::max(src_lengths).item().toInt(),
logits.size(1) == at::max(logit_lengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
logits.size(2) == at::max(target_lengths).item().toInt() + 1,
"output length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(tgt_lengths).item().toInt(),
targets.size(1) == at::max(target_lengths).item().toInt(),
"target length mismatch");
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_smax;
options.fusedLogSmax_ = fused_log_softmax;
CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
......@@ -124,8 +127,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
......@@ -136,8 +139,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<c10::Half>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr
......
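
The GPU entry point mirrors the CPU checks with the same renamed arguments; the difference visible here is that all tensors must sit on a CUDA device. A minimal sketch, assuming a CUDA-enabled torchaudio build:

```python
# Illustrative sketch only: the GPU path expects every tensor on the same CUDA
# device; the argument names match the CPU path after this rename.
import torch
import torchaudio  # noqa: F401

device = torch.device("cuda")
batch, max_src, max_tgt, num_classes = 2, 8, 3, 10
logits = torch.rand(batch, max_src, max_tgt + 1, num_classes, device=device)
targets = torch.randint(1, num_classes, (batch, max_tgt), dtype=torch.int32, device=device)
logit_lengths = torch.full((batch,), max_src, dtype=torch.int32, device=device)
target_lengths = torch.full((batch,), max_tgt, dtype=torch.int32, device=device)

costs, _ = torch.ops.torchaudio.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank=0,
    clamp=-1.0,
    fused_log_softmax=True,
    reuse_logits_for_grads=True,
)
```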
......@@ -9,13 +9,13 @@ namespace gpu {
torch::Tensor compute_alphas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
......@@ -58,8 +58,8 @@ torch::Tensor compute_alphas(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*alphas=*/alphas.data_ptr<float>());
return alphas;
}
......
......@@ -9,13 +9,13 @@ namespace gpu {
torch::Tensor compute_betas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
......@@ -28,7 +28,7 @@ torch::Tensor compute_betas(
options.device_ = GPU;
torch::Tensor costs = torch::empty(
tgt_lengths.size(0),
target_lengths.size(0),
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor betas = torch::zeros(
......@@ -62,8 +62,8 @@ torch::Tensor compute_betas(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
/*target_lengths=*/target_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*betas=*/betas.data_ptr<float>());
return betas;
......
......@@ -51,11 +51,11 @@ def rnnt_loss(
costs, gradients = torch.ops.torchaudio.rnnt_loss(
logits=logits,
targets=targets,
src_lengths=logit_lengths,
tgt_lengths=target_lengths,
logit_lengths=logit_lengths,
target_lengths=target_lengths,
blank=blank,
clamp=clamp,
fused_log_smax=fused_log_softmax,
fused_log_softmax=fused_log_softmax,
reuse_logits_for_grads=reuse_logits_for_grads,)
return costs
......
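
On the Python side the public keyword names (logit_lengths, target_lengths, fused_log_softmax) were already in use, so this change only aligns the op-level parameter names with them. A usage sketch of the Python wrapper follows; the import path and any keywords beyond what this hunk shows are assumptions:

```python
# Illustrative sketch only. The import path is an assumption; this diff shows
# only the body of the Python-side rnnt_loss wrapper.
import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss  # assumed location

batch, max_src, max_tgt, num_classes = 2, 10, 5, 20
logits = torch.rand(batch, max_src, max_tgt + 1, num_classes)
targets = torch.randint(1, num_classes, (batch, max_tgt), dtype=torch.int32)
logit_lengths = torch.full((batch,), max_src, dtype=torch.int32)
target_lengths = torch.full((batch,), max_tgt, dtype=torch.int32)

# The wrapper forwards these names unchanged to the renamed op.
costs = rnnt_loss(
    logits, targets, logit_lengths, target_lengths,
    fused_log_softmax=True,
)
```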