Unverified Commit 32f661f0 authored by Caroline Chen, committed by GitHub

Migrate RNNTLoss input checks to C++ (#1494)

parent 723e9a52
@@ -15,6 +15,62 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
double clamp,
bool fused_log_smax = true,
bool reuse_logits_for_grads = true) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
TORCH_CHECK(
logits.device().type() == src_lengths.device().type(),
"logits and logit_lengths must be on the same device");
TORCH_CHECK(
logits.device().type() == tgt_lengths.device().type(),
"logits and target_lengths must be on the same device");
TORCH_CHECK(
logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
"logits must be float32 or float16 (half) type");
TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
TORCH_CHECK(
src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
TORCH_CHECK(
tgt_lengths.dtype() == torch::kInt32,
"target_lengths must be int32 type");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(
src_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and logit_lengths");
TORCH_CHECK(
tgt_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and target_lengths");
TORCH_CHECK(
targets.size(0) == logits.size(0),
"batch dimension mismatch between logits and targets");
TORCH_CHECK(
blank >= 0 && blank < logits.size(-1),
"blank must be within [0, logits.shape[-1])");
TORCH_CHECK(
logits.size(1) == at::max(src_lengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
"output length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(tgt_lengths).item().toInt(),
"target length mismatch");
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
@@ -16,6 +16,62 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
double clamp,
bool fused_log_smax = true,
bool reuse_logits_for_grads = true) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
TORCH_CHECK(
logits.device().type() == src_lengths.device().type(),
"logits and logit_lengths must be on the same device");
TORCH_CHECK(
logits.device().type() == tgt_lengths.device().type(),
"logits and target_lengths must be on the same device");
TORCH_CHECK(
logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
"logits must be float32 or float16 (half) type");
TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
TORCH_CHECK(
src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
TORCH_CHECK(
tgt_lengths.dtype() == torch::kInt32,
"target_lengths must be int32 type");
TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
TORCH_CHECK(
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
TORCH_CHECK(
src_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and logit_lengths");
TORCH_CHECK(
tgt_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and target_lengths");
TORCH_CHECK(
targets.size(0) == logits.size(0),
"batch dimension mismatch between logits and targets");
TORCH_CHECK(
blank >= 0 && blank < logits.size(-1),
"blank must be within [0, logits.shape[-1])");
TORCH_CHECK(
logits.size(1) == at::max(src_lengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
"output length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(tgt_lengths).item().toInt(),
"target length mismatch");
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
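The identical validation is added to both C++ binding files above, so malformed inputs now fail inside the operator itself. As a rough illustration, the sketch below (the `rnnt_loss` import path is assumed rather than taken from this diff) builds tensors that satisfy every TORCH_CHECK and then deliberately trips one of the new dtype checks:

```python
# Sketch only: shapes/dtypes that satisfy the TORCH_CHECKs added above.
# from torchaudio.prototype import rnnt_loss  # import path assumed; adjust to wherever rnnt_loss lives
import torch

batch, max_time, max_target, num_classes = 2, 5, 3, 10

# 4-D float32 logits: (batch, time, target, class); torch.rand returns a contiguous tensor.
logits = torch.rand(batch, max_time, max_target + 1, num_classes)
# 2-D int32 targets: (batch, max target length); values avoid the blank index used below.
targets = torch.randint(1, num_classes, (batch, max_target), dtype=torch.int32)
# 1-D int32 lengths whose maxima match logits.size(1) and logits.size(2) - 1.
logit_lengths = torch.full((batch,), max_time, dtype=torch.int32)
target_lengths = torch.full((batch,), max_target, dtype=torch.int32)

# Valid call (return value assumed to be per-sequence costs).
costs = rnnt_loss(logits, targets, logit_lengths, target_lengths, blank=0)

# A dtype violation now fails inside the operator with a RuntimeError from TORCH_CHECK,
# e.g. "targets must be int32 type", instead of a Python-side TypeError.
try:
    rnnt_loss(logits, targets.to(torch.int64), logit_lengths, target_lengths, blank=0)
except RuntimeError as err:
    print(err)
```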
@@ -80,7 +80,6 @@ class _RNNT(torch.autograd.Function):
target_lengths,
blank=-1,
clamp=-1,
runtime_check=False,
fused_log_softmax=True,
reuse_logits_for_grads=True,
):
@@ -101,15 +100,6 @@ class _RNNT(torch.autograd.Function):
if blank < 0: # reinterpret blank index if blank < 0.
blank = logits.shape[-1] + blank
if runtime_check:
check_inputs(
logits=logits,
targets=targets,
logit_lengths=logit_lengths,
target_lengths=target_lengths,
blank=blank,
)
costs, gradients = torch.ops.torchaudio.rnnt_loss(
logits=logits,
targets=targets,
@@ -137,7 +127,6 @@ class _RNNT(torch.autograd.Function):
None, # target_lengths
None, # blank
None, # clamp
None, # runtime_check
None, # fused_log_softmax
None, # reuse_logits_for_grads
)
@@ -150,7 +139,6 @@ def rnnt_loss(
target_lengths,
blank=-1,
clamp=-1,
runtime_check=False,
fused_log_softmax=True,
reuse_logits_for_grads=True,
):
@@ -185,7 +173,6 @@ def rnnt_loss(
target_lengths,
blank,
clamp,
runtime_check,
fused_log_softmax,
reuse_logits_for_grads,
)
@@ -203,7 +190,6 @@ class RNNTLoss(torch.nn.Module):
Args:
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
"""
@@ -212,14 +198,12 @@ class RNNTLoss(torch.nn.Module):
self,
blank=-1,
clamp=-1,
runtime_check=False,
fused_log_softmax=True,
reuse_logits_for_grads=True,
):
super().__init__()
self.blank = blank
self.clamp = clamp
self.runtime_check = runtime_check
self.fused_log_softmax = fused_log_softmax
self.reuse_logits_for_grads = reuse_logits_for_grads
@@ -244,94 +228,6 @@ class RNNTLoss(torch.nn.Module):
target_lengths,
self.blank,
self.clamp,
self.runtime_check,
self.fused_log_softmax,
self.reuse_logits_for_grads,
)

def check_type(var, t, name):
if var.dtype is not t:
raise TypeError("{} must be {}".format(name, t))

def check_contiguous(var, name):
if not var.is_contiguous():
raise ValueError("{} must be contiguous".format(name))

def check_dim(var, dim, name):
if len(var.shape) != dim:
raise ValueError("{} must be {}D".format(name, dim))

def check_equal(var1, name1, var2, name2):
if var1 != var2:
raise ValueError(
"`{}` ({}) must equal to ".format(name1, var1)
+ "`{}` ({})".format(name2, var2)
)

def check_device(var1, name1, var2, name2):
if var1.device != var2.device:
raise ValueError(
"`{}` ({}) must be on the same ".format(name1, var1.device.type)
+ "device as `{}` ({})".format(name2, var2.device.type)
)

def check_inputs(logits, targets, logit_lengths, target_lengths, blank):
check_device(logits, "logits", targets, "targets")
check_device(logits, "logits", targets, "logit_lengths")
check_device(logits, "logits", targets, "target_lengths")
check_type(logits, torch.float32, "logits")
check_type(targets, torch.int32, "targets")
check_type(logit_lengths, torch.int32, "logit_lengths")
check_type(target_lengths, torch.int32, "target_lengths")
check_contiguous(logits, "logits")
check_contiguous(targets, "targets")
check_contiguous(target_lengths, "target_lengths")
check_contiguous(logit_lengths, "logit_lengths")
check_dim(logits, 4, "logits")
check_dim(targets, 2, "targets")
check_dim(logit_lengths, 1, "logit_lengths")
check_dim(target_lengths, 1, "target_lengths")
check_equal(
logit_lengths.shape[0], "logit_lengths.shape[0]", logits.shape[0], "logits.shape[0]"
)
check_equal(
target_lengths.shape[0], "target_lengths.shape[0]", logits.shape[0], "logits.shape[0]"
)
check_equal(
targets.shape[0], "targets.shape[0]", logits.shape[0], "logits.shape[0]"
)
check_equal(
targets.shape[1],
"targets.shape[1]",
torch.max(target_lengths),
"torch.max(target_lengths)",
)
check_equal(
logits.shape[1],
"logits.shape[1]",
torch.max(logit_lengths),
"torch.max(logit_lengths)",
)
check_equal(
logits.shape[2],
"logits.shape[2]",
torch.max(target_lengths) + 1,
"torch.max(target_lengths) + 1",
)
if blank < 0 or blank >= logits.shape[-1]:
raise ValueError(
"blank ({}) must be within [0, logits.shape[-1]={})".format(
blank, logits.shape[-1]
)
)