OpenDAS / Torchaudio · Commits

Commit f06074aa (unverified)
Migrate transducer input checks to C++ (#1391)
Authored Mar 16, 2021 by Caroline Chen; committed by GitHub, Mar 16, 2021
Parent: 1ebfb3de

Changes: 2 changed files, with 29 additions and 43 deletions (+29 -43)
  torchaudio/csrc/transducer.cpp      +29  -0
  torchaudio/prototype/transducer.py   +0 -43
torchaudio/csrc/transducer.cpp

@@ -17,11 +17,40 @@ int64_t cpu_rnnt_loss(
     torch::Tensor grads,
     int64_t blank_label,
     int64_t num_threads) {
+  TORCH_CHECK(labels.dtype() == torch::kInt32, "labels must be int32 type");
+  TORCH_CHECK(
+      label_lengths.dtype() == torch::kInt32,
+      "label_lengths must be int32 type");
+  TORCH_CHECK(input_lengths.dtype() == torch::kInt32, "lengths must be int32 type");
+
+  TORCH_CHECK(acts.is_contiguous(), "acts must be contiguous");
+  TORCH_CHECK(labels.is_contiguous(), "labels must be contiguous");
+  TORCH_CHECK(label_lengths.is_contiguous(), "label_lengths must be contiguous");
+  TORCH_CHECK(input_lengths.is_contiguous(), "lengths must be contiguous");
+
+  TORCH_CHECK(
+      input_lengths.size(0) == acts.size(0),
+      "batch dimension mismatch between acts and input_lengths: each example must have a length");
+  TORCH_CHECK(
+      label_lengths.size(0) == acts.size(0),
+      "batch dimension mismatch between acts and label_lengths: each example must have a label length");
+
+  TORCH_CHECK(acts.dim() == 4, "acts must be 4-D (batch, time, label, class)");
+  TORCH_CHECK(labels.dim() == 2, "labels must be 2-D (batch, max label length)");
+  TORCH_CHECK(input_lengths.dim() == 1, "input_lengths must be 1-D");
+  TORCH_CHECK(label_lengths.dim() == 1, "label_lengths must be 1-D");
+
   int maxT = acts.size(1);
   int maxU = acts.size(2);
   int minibatch_size = acts.size(0);
   int alphabet_size = acts.size(3);
 
+  TORCH_CHECK(
+      at::max(input_lengths).item().toInt() == maxT, "input length mismatch");
+  TORCH_CHECK(
+      at::max(label_lengths).item().toInt() + 1 == maxU,
+      "output length mismatch");
+
   rnntOptions options;
   memset(&options, 0, sizeof(options));
   options.maxT = maxT;
...
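Read together, the new TORCH_CHECK guards spell out the input contract of cpu_rnnt_loss: acts is a contiguous 4-D tensor of shape (batch, time, label, class); labels is a contiguous 2-D int32 tensor of (batch, max label length); both length tensors are contiguous 1-D int32 with one entry per example; and the acts time and label dimensions must equal the largest input length and the largest label length plus one (the extra slot accounts for the blank). A minimal Python sketch of inputs that would pass every check; the concrete sizes are invented for illustration:

import torch

# Hypothetical sizes: batch of 2, 10 time steps, labels up to 4 symbols, 5 classes.
B, T, L, C = 2, 10, 4, 5

acts = torch.rand(B, T, L + 1, C)                        # 4-D (batch, time, label, class); label dim = max label length + 1
labels = torch.randint(1, C, (B, L), dtype=torch.int32)  # 2-D (batch, max label length), int32
act_lens = torch.full((B,), T, dtype=torch.int32)        # 1-D, one length per example; max equals acts.size(1)
label_lens = torch.full((B,), L, dtype=torch.int32)      # 1-D, one label length per example; max + 1 equals acts.size(2)

# All four tensors are freshly allocated, hence contiguous, which satisfies
# the is_contiguous() checks as well.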
torchaudio/prototype/transducer.py

@@ -19,7 +19,6 @@ class _RNNT(Function):
         """
         device = acts.device
-        check_inputs(acts, labels, act_lens, label_lens)
         acts = acts.to("cpu")
         labels = labels.to("cpu")

@@ -118,45 +117,3 @@ class RNNTLoss(Module):
             # log_softmax is computed within GPU version.
             acts = torch.nn.functional.log_softmax(acts, -1)
         return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction)
-
-
-def check_type(var, t, name):
-    if var.dtype is not t:
-        raise TypeError("{} must be {}".format(name, t))
-
-
-def check_contiguous(var, name):
-    if not var.is_contiguous():
-        raise ValueError("{} must be contiguous".format(name))
-
-
-def check_dim(var, dim, name):
-    if len(var.shape) != dim:
-        raise ValueError("{} must be {}D".format(name, dim))
-
-
-def check_inputs(log_probs, labels, lengths, label_lengths):
-    check_type(labels, torch.int32, "labels")
-    check_type(label_lengths, torch.int32, "label_lengths")
-    check_type(lengths, torch.int32, "lengths")
-    check_contiguous(log_probs, "log_probs")
-    check_contiguous(labels, "labels")
-    check_contiguous(label_lengths, "label_lengths")
-    check_contiguous(lengths, "lengths")
-
-    if lengths.shape[0] != log_probs.shape[0]:
-        raise ValueError("must have a length per example.")
-    if label_lengths.shape[0] != log_probs.shape[0]:
-        raise ValueError("must have a label length per example.")
-
-    check_dim(log_probs, 4, "log_probs")
-    check_dim(labels, 2, "labels")
-    check_dim(lengths, 1, "lengths")
-    check_dim(label_lengths, 1, "label_lengths")
-
-    max_T = torch.max(lengths)
-    max_U = torch.max(label_lengths)
-    T, U = log_probs.shape[1:3]
-
-    if T != max_T:
-        raise ValueError("Input length mismatch")
-    if U != max_U + 1:
-        raise ValueError("Output length mismatch")
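One user-visible effect of the migration: validation failures are now raised by TORCH_CHECK inside the C++ binding, which surfaces in Python as a RuntimeError, whereas the deleted pure-Python helpers raised TypeError or ValueError. A sketch of what calling code might observe, assuming a torchaudio build with the prototype transducer enabled and a default-constructed RNNTLoss (the import path follows the file shown above):

import torch
from torchaudio.prototype.transducer import RNNTLoss

rnnt_loss = RNNTLoss()
acts = torch.rand(2, 10, 5, 5)
act_lens = torch.full((2,), 10, dtype=torch.int32)
label_lens = torch.full((2,), 4, dtype=torch.int32)
bad_labels = torch.randint(1, 5, (2, 4))  # default dtype is int64, not int32

try:
    rnnt_loss(acts, bad_labels, act_lens, label_lens)
except RuntimeError as e:  # was TypeError while check_type ran in Python
    print(e)  # message now comes from the C++ check: "labels must be int32 type"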