Unverified Commit 2376e9c9 authored by Caroline Chen, committed by GitHub

Rename RNNT loss C++ parameters (#1602)

parent 6a8ecd98
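This commit renames the RNN Transducer loss parameters src_lengths, tgt_lengths, and fused_log_smax to logit_lengths, target_lengths, and fused_log_softmax across the C++ sources, the registered operator schema, and the Python binding. For illustration only, a minimal sketch of calling the renamed operator from Python; the toy shapes, blank index, and clamp value are assumptions rather than part of this commit, and it presumes a torchaudio build with the RNNT extension compiled in:

    import torch
    import torchaudio  # noqa: F401 -- importing registers the torchaudio::rnnt_loss op (assumed)

    # Assumed toy shapes: batch=1, time=2, max target length=2, 5 output classes.
    logits = torch.rand(1, 2, 3, 5, dtype=torch.float32)   # (batch, time, target_len + 1, class)
    targets = torch.tensor([[1, 2]], dtype=torch.int32)    # (batch, max target length)
    logit_lengths = torch.tensor([2], dtype=torch.int32)   # was src_lengths
    target_lengths = torch.tensor([2], dtype=torch.int32)  # was tgt_lengths

    costs, gradients = torch.ops.torchaudio.rnnt_loss(
        logits=logits,
        targets=targets,
        logit_lengths=logit_lengths,
        target_lengths=target_lengths,
        blank=0,                 # assumed blank index
        clamp=-1.0,              # assumed: negative value, i.e. no gradient clamping
        fused_log_softmax=True,  # was fused_log_smax
        reuse_logits_for_grads=True,
    )
    print(costs.shape)  # one cost per utterance in the batch

The op returns per-utterance costs together with an optional gradient tensor shaped like logits; the higher-level Python wrapper shown at the bottom of this diff forwards its keyword arguments to this op.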
@@ -10,21 +10,21 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       torch::autograd::AutogradContext* ctx,
       torch::Tensor& logits,
       const torch::Tensor& targets,
-      const torch::Tensor& src_lengths,
-      const torch::Tensor& tgt_lengths,
+      const torch::Tensor& logit_lengths,
+      const torch::Tensor& target_lengths,
       int64_t blank,
       double clamp,
-      bool fused_log_smax = true,
+      bool fused_log_softmax = true,
       bool reuse_logits_for_grads = true) {
     torch::Tensor undef;
     auto result = rnnt_loss(
         logits,
         targets,
-        src_lengths,
-        tgt_lengths,
+        logit_lengths,
+        target_lengths,
         blank,
         clamp,
-        fused_log_smax,
+        fused_log_softmax,
         reuse_logits_for_grads);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
@@ -47,21 +47,21 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax = true,
+    bool fused_log_softmax = true,
     bool reuse_logits_for_grads = true) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
       logits,
       targets,
-      src_lengths,
-      tgt_lengths,
+      logit_lengths,
+      target_lengths,
       blank,
       clamp,
-      fused_log_smax,
+      fused_log_softmax,
       reuse_logits_for_grads);
   return std::make_tuple(results[0], results[1]);
 }
...
@@ -4,11 +4,11 @@
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax = true,
+    bool fused_log_softmax = true,
     bool reuse_logits_for_grads = true) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
@@ -16,11 +16,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
   return op.call(
       logits,
       targets,
-      src_lengths,
-      tgt_lengths,
+      logit_lengths,
+      target_lengths,
       blank,
       clamp,
-      fused_log_smax,
+      fused_log_softmax,
       reuse_logits_for_grads);
 }
@@ -28,10 +28,10 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
       "rnnt_loss(Tensor logits,"
       "Tensor targets,"
-      "Tensor src_lengths,"
-      "Tensor tgt_lengths,"
+      "Tensor logit_lengths,"
+      "Tensor target_lengths,"
       "int blank,"
       "float clamp,"
-      "bool fused_log_smax=True,"
+      "bool fused_log_softmax=True,"
      "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
 }
@@ -5,9 +5,9 @@
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax,
+    bool fused_log_softmax,
     bool reuse_logits_for_grads);
@@ -4,8 +4,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
       "rnnt_loss_alphas(Tensor logits,"
       "Tensor targets,"
-      "Tensor src_lengths,"
-      "Tensor tgt_lengths,"
+      "Tensor logit_lengths,"
+      "Tensor target_lengths,"
       "int blank,"
       "float clamp) -> Tensor");
 }
@@ -4,8 +4,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
       "rnnt_loss_betas(Tensor logits,"
       "Tensor targets,"
-      "Tensor src_lengths,"
-      "Tensor tgt_lengths,"
+      "Tensor logit_lengths,"
+      "Tensor target_lengths,"
       "int blank,"
       "float clamp) -> Tensor");
 }
@@ -9,20 +9,20 @@ namespace cpu {
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax = true,
+    bool fused_log_softmax = true,
     bool reuse_logits_for_grads = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
   TORCH_CHECK(
-      logits.device().type() == src_lengths.device().type(),
+      logits.device().type() == logit_lengths.device().type(),
       "logits and logit_lengths must be on the same device");
   TORCH_CHECK(
-      logits.device().type() == tgt_lengths.device().type(),
+      logits.device().type() == target_lengths.device().type(),
       "logits and target_lengths must be on the same device");
   TORCH_CHECK(
@@ -30,28 +30,31 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       "logits must be float32 or float16 (half) type");
   TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
   TORCH_CHECK(
-      src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
+      logit_lengths.dtype() == torch::kInt32,
+      "logit_lengths must be int32 type");
   TORCH_CHECK(
-      tgt_lengths.dtype() == torch::kInt32,
+      target_lengths.dtype() == torch::kInt32,
       "target_lengths must be int32 type");
   TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
-  TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
+  TORCH_CHECK(
+      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
+  TORCH_CHECK(
+      target_lengths.is_contiguous(), "target_lengths must be contiguous");
   TORCH_CHECK(
       logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
   TORCH_CHECK(
       targets.dim() == 2, "targets must be 2-D (batch, max target length)");
-  TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
-  TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
+  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
+  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
   TORCH_CHECK(
-      src_lengths.size(0) == logits.size(0),
+      logit_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and logit_lengths");
   TORCH_CHECK(
-      tgt_lengths.size(0) == logits.size(0),
+      target_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and target_lengths");
   TORCH_CHECK(
       targets.size(0) == logits.size(0),
@@ -62,24 +65,24 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       "blank must be within [0, logits.shape[-1])");
   TORCH_CHECK(
-      logits.size(1) == at::max(src_lengths).item().toInt(),
+      logits.size(1) == at::max(logit_lengths).item().toInt(),
       "input length mismatch");
   TORCH_CHECK(
-      logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
+      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
       "output length mismatch");
   TORCH_CHECK(
-      targets.size(1) == at::max(tgt_lengths).item().toInt(),
+      targets.size(1) == at::max(target_lengths).item().toInt(),
       "target length mismatch");
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_smax;
+  options.fusedLogSmax_ = fused_log_softmax;
   CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;
@@ -121,8 +124,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
         /*workspace=*/workspace,
         /*logits=*/logits.data_ptr<float>(),
         /*targets=*/targets.data_ptr<int>(),
-        /*src_lengths=*/src_lengths.data_ptr<int>(),
-        /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+        /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+        /*target_lengths=*/target_lengths.data_ptr<int>(),
         /*costs=*/costs.data_ptr<float>(),
         /*gradients=*/
         (gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
@@ -133,8 +136,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
         /*workspace=*/workspace,
         /*logits=*/logits.data_ptr<c10::Half>(),
         /*targets=*/targets.data_ptr<int>(),
-        /*src_lengths=*/src_lengths.data_ptr<int>(),
-        /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+        /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+        /*target_lengths=*/target_lengths.data_ptr<int>(),
         /*costs=*/costs.data_ptr<c10::Half>(),
         /*gradients=*/
         (gradients == c10::nullopt) ? nullptr
...
@@ -8,13 +8,13 @@ namespace cpu {
 torch::Tensor compute_alphas(
     const torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
@@ -55,8 +55,8 @@ torch::Tensor compute_alphas(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*alphas=*/alphas.data_ptr<float>());
   return alphas;
 }
...
@@ -8,13 +8,13 @@ namespace cpu {
 torch::Tensor compute_betas(
     const torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
@@ -25,7 +25,7 @@ torch::Tensor compute_betas(
   options.device_ = CPU;
   torch::Tensor costs = torch::empty(
-      tgt_lengths.size(0),
+      target_lengths.size(0),
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
   torch::Tensor betas = torch::zeros(
@@ -59,8 +59,8 @@ torch::Tensor compute_betas(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*costs=*/costs.data_ptr<float>(),
       /*betas=*/betas.data_ptr<float>());
   return betas;
...
@@ -10,20 +10,20 @@ namespace gpu {
 std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_smax = true,
+    bool fused_log_softmax = true,
     bool reuse_logits_for_grads = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
   TORCH_CHECK(
-      logits.device().type() == src_lengths.device().type(),
+      logits.device().type() == logit_lengths.device().type(),
       "logits and logit_lengths must be on the same device");
   TORCH_CHECK(
-      logits.device().type() == tgt_lengths.device().type(),
+      logits.device().type() == target_lengths.device().type(),
       "logits and target_lengths must be on the same device");
   TORCH_CHECK(
@@ -31,28 +31,31 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       "logits must be float32 or float16 (half) type");
   TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
   TORCH_CHECK(
-      src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
+      logit_lengths.dtype() == torch::kInt32,
+      "logit_lengths must be int32 type");
   TORCH_CHECK(
-      tgt_lengths.dtype() == torch::kInt32,
+      target_lengths.dtype() == torch::kInt32,
       "target_lengths must be int32 type");
   TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
-  TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
+  TORCH_CHECK(
+      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
+  TORCH_CHECK(
+      target_lengths.is_contiguous(), "target_lengths must be contiguous");
   TORCH_CHECK(
       logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
   TORCH_CHECK(
       targets.dim() == 2, "targets must be 2-D (batch, max target length)");
-  TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
-  TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
+  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
+  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
   TORCH_CHECK(
-      src_lengths.size(0) == logits.size(0),
+      logit_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and logit_lengths");
   TORCH_CHECK(
-      tgt_lengths.size(0) == logits.size(0),
+      target_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and target_lengths");
   TORCH_CHECK(
       targets.size(0) == logits.size(0),
@@ -63,24 +66,24 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       "blank must be within [0, logits.shape[-1])");
   TORCH_CHECK(
-      logits.size(1) == at::max(src_lengths).item().toInt(),
+      logits.size(1) == at::max(logit_lengths).item().toInt(),
       "input length mismatch");
   TORCH_CHECK(
-      logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
+      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
       "output length mismatch");
   TORCH_CHECK(
-      targets.size(1) == at::max(tgt_lengths).item().toInt(),
+      targets.size(1) == at::max(target_lengths).item().toInt(),
       "target length mismatch");
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_smax;
+  options.fusedLogSmax_ = fused_log_softmax;
   CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
   options.stream_ = at::cuda::getCurrentCUDAStream();
@@ -124,8 +127,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
         /*workspace=*/workspace,
         /*logits=*/logits.data_ptr<float>(),
         /*targets=*/targets.data_ptr<int>(),
-        /*src_lengths=*/src_lengths.data_ptr<int>(),
-        /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+        /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+        /*target_lengths=*/target_lengths.data_ptr<int>(),
         /*costs=*/costs.data_ptr<float>(),
         /*gradients=*/
         (gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
@@ -136,8 +139,8 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
         /*workspace=*/workspace,
         /*logits=*/logits.data_ptr<c10::Half>(),
         /*targets=*/targets.data_ptr<int>(),
-        /*src_lengths=*/src_lengths.data_ptr<int>(),
-        /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+        /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+        /*target_lengths=*/target_lengths.data_ptr<int>(),
         /*costs=*/costs.data_ptr<c10::Half>(),
         /*gradients=*/
         (gradients == c10::nullopt) ? nullptr
...
@@ -9,13 +9,13 @@ namespace gpu {
 torch::Tensor compute_alphas(
     const torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
@@ -58,8 +58,8 @@ torch::Tensor compute_alphas(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*alphas=*/alphas.data_ptr<float>());
   return alphas;
 }
...
@@ -9,13 +9,13 @@ namespace gpu {
 torch::Tensor compute_betas(
     const torch::Tensor& logits,
     const torch::Tensor& targets,
-    const torch::Tensor& src_lengths,
-    const torch::Tensor& tgt_lengths,
+    const torch::Tensor& logit_lengths,
+    const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = src_lengths.size(0);
-  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
+  options.batchSize_ = logit_lengths.size(0);
+  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
   options.maxSrcLen_ = logits.size(1);
   options.maxTgtLen_ = logits.size(2);
   options.numTargets_ = logits.size(3);
@@ -28,7 +28,7 @@ torch::Tensor compute_betas(
   options.device_ = GPU;
   torch::Tensor costs = torch::empty(
-      tgt_lengths.size(0),
+      target_lengths.size(0),
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
   torch::Tensor betas = torch::zeros(
@@ -62,8 +62,8 @@ torch::Tensor compute_betas(
       /*workspace=*/workspace,
       /*logits=*/logits.data_ptr<float>(),
       /*targets=*/targets.data_ptr<int>(),
-      /*src_lengths=*/src_lengths.data_ptr<int>(),
-      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
+      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
+      /*target_lengths=*/target_lengths.data_ptr<int>(),
       /*costs=*/costs.data_ptr<float>(),
       /*betas=*/betas.data_ptr<float>());
   return betas;
...
@@ -51,11 +51,11 @@ def rnnt_loss(
     costs, gradients = torch.ops.torchaudio.rnnt_loss(
         logits=logits,
         targets=targets,
-        src_lengths=logit_lengths,
-        tgt_lengths=target_lengths,
+        logit_lengths=logit_lengths,
+        target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_smax=fused_log_softmax,
+        fused_log_softmax=fused_log_softmax,
         reuse_logits_for_grads=reuse_logits_for_grads,)
     return costs
...