Unverified Commit f06074aa authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Migrate transducer input checks to C++ (#1391)

parent 1ebfb3de
...@@ -17,11 +17,40 @@ int64_t cpu_rnnt_loss( ...@@ -17,11 +17,40 @@ int64_t cpu_rnnt_loss(
torch::Tensor grads, torch::Tensor grads,
int64_t blank_label, int64_t blank_label,
int64_t num_threads) { int64_t num_threads) {
TORCH_CHECK(labels.dtype() == torch::kInt32, "labels must be int32 type");
TORCH_CHECK(
label_lengths.dtype() == torch::kInt32,
"label_lengths must be int32 type");
TORCH_CHECK(
input_lengths.dtype() == torch::kInt32, "lengths must be int32 type");
TORCH_CHECK(acts.is_contiguous(), "acts must be contiguous");
TORCH_CHECK(labels.is_contiguous(), "labels must be contiguous");
TORCH_CHECK(
label_lengths.is_contiguous(), "label_lengths must be contiguous");
TORCH_CHECK(input_lengths.is_contiguous(), "lengths must be contiguous");
TORCH_CHECK(
input_lengths.size(0) == acts.size(0),
"batch dimension mismatch between acts and input_lengths: each example must have a length");
TORCH_CHECK(
label_lengths.size(0) == acts.size(0),
"batch dimension mismatch between acts and label_lengths: each example must have a label length");
TORCH_CHECK(acts.dim() == 4, "acts must be 4-D (batch, time, label, class)");
TORCH_CHECK(
labels.dim() == 2, "labels must be 2-D (batch, max label length)");
TORCH_CHECK(input_lengths.dim() == 1, "input_lengths must be 1-D");
TORCH_CHECK(label_lengths.dim() == 1, "label_lengths must be 1-D");
int maxT = acts.size(1); int maxT = acts.size(1);
int maxU = acts.size(2); int maxU = acts.size(2);
int minibatch_size = acts.size(0); int minibatch_size = acts.size(0);
int alphabet_size = acts.size(3); int alphabet_size = acts.size(3);
TORCH_CHECK(
at::max(input_lengths).item().toInt() == maxT, "input length mismatch");
TORCH_CHECK(
at::max(label_lengths).item().toInt() + 1 == maxU,
"output length mismatch");
rnntOptions options; rnntOptions options;
memset(&options, 0, sizeof(options)); memset(&options, 0, sizeof(options));
options.maxT = maxT; options.maxT = maxT;
......
...@@ -19,7 +19,6 @@ class _RNNT(Function): ...@@ -19,7 +19,6 @@ class _RNNT(Function):
""" """
device = acts.device device = acts.device
check_inputs(acts, labels, act_lens, label_lens)
acts = acts.to("cpu") acts = acts.to("cpu")
labels = labels.to("cpu") labels = labels.to("cpu")
...@@ -118,45 +117,3 @@ class RNNTLoss(Module): ...@@ -118,45 +117,3 @@ class RNNTLoss(Module):
# log_softmax is computed within GPU version. # log_softmax is computed within GPU version.
acts = torch.nn.functional.log_softmax(acts, -1) acts = torch.nn.functional.log_softmax(acts, -1)
return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction) return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction)
def check_type(var, t, name):
    # Validate that tensor `var` carries the required dtype `t`;
    # `name` is only used to build the error message.
    actual = var.dtype
    if actual is t:
        return
    raise TypeError("{} must be {}".format(name, t))
def check_contiguous(var, name):
    # Reject tensors whose memory layout is not contiguous; the C++
    # kernel indexes raw storage and requires contiguity.
    if var.is_contiguous():
        return
    raise ValueError("{} must be contiguous".format(name))
def check_dim(var, dim, name):
    # Validate the tensor's rank against the expected number of dims.
    ndim = len(var.shape)
    if ndim == dim:
        return
    raise ValueError("{} must be {}D".format(name, dim))
def check_inputs(log_probs, labels, lengths, label_lengths):
    """Validate dtype, contiguity, and shape of transducer-loss inputs.

    Raises TypeError on a dtype mismatch and ValueError on a layout or
    shape problem; returns None when every check passes.
    """
    # The C++ kernel requires all index tensors to be int32.
    for tensor, name in (
        (labels, "labels"),
        (label_lengths, "label_lengths"),
        (lengths, "lengths"),
    ):
        check_type(tensor, torch.int32, name)

    # Every tensor is handed to C++ as raw storage, so it must be contiguous.
    for tensor, name in (
        (log_probs, "log_probs"),
        (labels, "labels"),
        (label_lengths, "label_lengths"),
        (lengths, "lengths"),
    ):
        check_contiguous(tensor, name)

    # One input length and one label length per batch example.
    batch = log_probs.shape[0]
    if lengths.shape[0] != batch:
        raise ValueError("must have a length per example.")
    if label_lengths.shape[0] != batch:
        raise ValueError("must have a label length per example.")

    # Expected ranks: 4-D activations, 2-D labels, 1-D length vectors.
    for tensor, ndim, name in (
        (log_probs, 4, "log_probs"),
        (labels, 2, "labels"),
        (lengths, 1, "lengths"),
        (label_lengths, 1, "label_lengths"),
    ):
        check_dim(tensor, ndim, name)

    # The padded time axis must equal the longest input, and the label
    # axis must be the longest label plus one (for the blank/start step).
    T, U = log_probs.shape[1:3]
    if T != torch.max(lengths):
        raise ValueError("Input length mismatch")
    if U != torch.max(label_lengths) + 1:
        raise ValueError("Output length mismatch")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment