Unverified Commit a21b08e3 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Remove unused RNNTL functions (#1518)

parent af7eb4d6
#pragma once
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
namespace torchaudio {
namespace rnnt {
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* gradients = nullptr) {
switch (workspace.GetOptions().device_) {
case CPU: {
status_t status = cpu::Compute<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*gradients=*/gradients);
return status;
}
case GPU: {
status_t status = gpu::Compute<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*gradients=*/gradients);
return status;
}
default: {
return FAILURE;
}
};
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* alphas) {
switch (workspace.GetOptions().device_) {
case CPU: {
status_t status = cpu::ComputeAlphas<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alphas=*/alphas);
return status;
}
case GPU: {
status_t status = gpu::ComputeAlphas<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/alphas);
return status;
}
default: {
return FAILURE;
}
};
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* betas) {
switch (workspace.GetOptions().device_) {
case CPU: {
status_t status = cpu::ComputeBetas<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*betas=*/betas);
return status;
}
case GPU: {
status_t status = gpu::ComputeBetas<DTYPE, CAST_DTYPE>(
/*workspace=*/workspace,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*costs=*/costs,
/*betas=*/betas);
return status;
}
default: {
return FAILURE;
}
};
}
} // namespace rnnt
} // namespace torchaudio
......@@ -7,52 +7,6 @@ __all__ = [
]
def _rnnt_loss_alphas(
logits,
targets,
logit_lengths,
target_lengths,
blank=-1,
clamp=-1,
):
"""
Compute alphas for RNN transducer loss.
See documentation for RNNTLoss
"""
return torch.ops.torchaudio.rnnt_loss_alphas(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
)
def _rnnt_loss_betas(
logits,
targets,
logit_lengths,
target_lengths,
blank=-1,
clamp=-1,
):
"""
Compute betas for RNN transducer loss
See documentation for RNNTLoss
"""
return torch.ops.torchaudio.rnnt_loss_betas(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
)
def rnnt_loss(
logits: Tensor,
targets: Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment