Commit 32f661f0 in OpenDAS / Torchaudio (unverified)
Authored May 06, 2021 by Caroline Chen; committed via GitHub on May 06, 2021
Migrate RNNT input checks to C++ (#1494)
Parent: 723e9a52

Showing 3 changed files, with 112 additions and 104 deletions:
torchaudio/csrc/rnnt/cpu/compute.cpp  (+56, -0)
torchaudio/csrc/rnnt/gpu/compute.cu   (+56, -0)
torchaudio/prototype/rnnt_loss.py     (+0, -104)
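With the runtime checks moved into the C++ operator, invalid inputs now fail inside torch.ops.torchaudio.rnnt_loss via TORCH_CHECK rather than in the Python wrapper, and no runtime_check flag is involved anymore. A minimal sketch of the user-visible effect, assuming the wrapper is importable from torchaudio.prototype.rnnt_loss; the tensor sizes are illustrative assumptions, not part of this commit:

    import torch
    from torchaudio.prototype.rnnt_loss import rnnt_loss

    logits = torch.rand(1, 10, 6, 20, dtype=torch.float32)
    targets = torch.randint(0, 20, (1, 5), dtype=torch.int64)  # wrong dtype: the new checks require int32
    logit_lengths = torch.tensor([10], dtype=torch.int32)
    target_lengths = torch.tensor([5], dtype=torch.int32)

    try:
        rnnt_loss(logits, targets, logit_lengths, target_lengths)
    except RuntimeError as err:
        # A TORCH_CHECK failure propagates to Python as a RuntimeError,
        # here carrying the message "targets must be int32 type".
        print(err)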
torchaudio/csrc/rnnt/cpu/compute.cpp

@@ -15,6 +15,62 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
      logits.device().type() == src_lengths.device().type(),
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
      logits.device().type() == tgt_lengths.device().type(),
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
      src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
  TORCH_CHECK(
      tgt_lengths.dtype() == torch::kInt32, "target_lengths must be int32 type");

  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
  TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");

  TORCH_CHECK(
      src_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
      tgt_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  TORCH_CHECK(
      logits.size(1) == at::max(src_lengths).item().toInt(),
      "input length mismatch");
  TORCH_CHECK(
      logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
      "output length mismatch");
  TORCH_CHECK(
      targets.size(1) == at::max(tgt_lengths).item().toInt(),
      "target length mismatch");

  Options options;
  options.batchSize_ = src_lengths.size(0);
  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
...
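These checks pin down the expected layout of the inputs. A small sketch of tensors that satisfy every constraint above; the concrete sizes are arbitrary assumptions chosen for illustration:

    import torch

    batch, num_classes = 2, 20
    logit_lengths = torch.tensor([10, 8], dtype=torch.int32)   # 1-D, int32, size == batch
    target_lengths = torch.tensor([5, 3], dtype=torch.int32)   # 1-D, int32, size == batch

    # logits: 4-D (batch, max logit length, max target length + 1, class), float32 or float16
    logits = torch.rand(batch, 10, 6, num_classes, dtype=torch.float32)
    # targets: 2-D (batch, max target length), int32
    targets = torch.randint(1, num_classes, (batch, 5), dtype=torch.int32)

    blank = 0  # must lie in [0, num_classes)

    # All tensors are contiguous and on the same device, so every TORCH_CHECK above passes:
    # logits.size(1) == max(logit_lengths), logits.size(2) == max(target_lengths) + 1,
    # and targets.size(1) == max(target_lengths).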
torchaudio/csrc/rnnt/gpu/compute.cu

@@ -16,6 +16,62 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
      logits.device().type() == src_lengths.device().type(),
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
      logits.device().type() == tgt_lengths.device().type(),
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
      src_lengths.dtype() == torch::kInt32, "logit_lengths must be int32 type");
  TORCH_CHECK(
      tgt_lengths.dtype() == torch::kInt32, "target_lengths must be int32 type");

  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
  TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");

  TORCH_CHECK(
      src_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
      tgt_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  TORCH_CHECK(
      logits.size(1) == at::max(src_lengths).item().toInt(),
      "input length mismatch");
  TORCH_CHECK(
      logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
      "output length mismatch");
  TORCH_CHECK(
      targets.size(1) == at::max(tgt_lengths).item().toInt(),
      "target length mismatch");

  Options options;
  options.batchSize_ = src_lengths.size(0);
  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
...
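The dtype check above also admits float16 logits. A hedged half-precision sketch on CUDA, assuming the underlying GPU kernels accept half inputs (only the check itself is visible in this hunk) and that the wrapper is importable as below; sizes are illustrative:

    import torch
    from torchaudio.prototype.rnnt_loss import rnnt_loss

    if torch.cuda.is_available():
        device = torch.device("cuda")
        logits = torch.rand(2, 10, 6, 20, dtype=torch.float16, device=device)
        targets = torch.randint(1, 20, (2, 5), dtype=torch.int32, device=device)
        logit_lengths = torch.tensor([10, 8], dtype=torch.int32, device=device)
        target_lengths = torch.tensor([5, 3], dtype=torch.int32, device=device)
        # Satisfies "logits must be float32 or float16 (half) type" and the same-device checks.
        costs = rnnt_loss(logits, targets, logit_lengths, target_lengths)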
torchaudio/prototype/rnnt_loss.py
@@ -80,7 +80,6 @@ class _RNNT(torch.autograd.Function):
        target_lengths,
        blank=-1,
        clamp=-1,
        runtime_check=False,
        fused_log_softmax=True,
        reuse_logits_for_grads=True,
    ):
@@ -101,15 +100,6 @@ class _RNNT(torch.autograd.Function):
        if blank < 0:
            # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank

        if runtime_check:
            check_inputs(
                logits=logits,
                targets=targets,
                logit_lengths=logit_lengths,
                target_lengths=target_lengths,
                blank=blank,
            )

        costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
@@ -137,7 +127,6 @@ class _RNNT(torch.autograd.Function):
            None,  # target_lengths
            None,  # blank
            None,  # clamp
            None,  # runtime_check
            None,  # fused_log_softmax
            None,  # reuse_logits_for_grads
        )
@@ -150,7 +139,6 @@ def rnnt_loss(
    target_lengths,
    blank=-1,
    clamp=-1,
    runtime_check=False,
    fused_log_softmax=True,
    reuse_logits_for_grads=True,
):
@@ -185,7 +173,6 @@ def rnnt_loss(
        target_lengths,
        blank,
        clamp,
        runtime_check,
        fused_log_softmax,
        reuse_logits_for_grads,
    )
@@ -203,7 +190,6 @@ class RNNTLoss(torch.nn.Module):
    Args:
        blank (int, opt): blank label (Default: ``-1``)
        clamp (float): clamp for gradients (Default: ``-1``)
        runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
    """
@@ -212,14 +198,12 @@ class RNNTLoss(torch.nn.Module):
        self,
        blank=-1,
        clamp=-1,
        runtime_check=False,
        fused_log_softmax=True,
        reuse_logits_for_grads=True,
    ):
        super().__init__()
        self.blank = blank
        self.clamp = clamp
        self.runtime_check = runtime_check
        self.fused_log_softmax = fused_log_softmax
        self.reuse_logits_for_grads = reuse_logits_for_grads
@@ -244,94 +228,6 @@ class RNNTLoss(torch.nn.Module):
            target_lengths,
            self.blank,
            self.clamp,
            self.runtime_check,
            self.fused_log_softmax,
            self.reuse_logits_for_grads,
        )


def check_type(var, t, name):
    if var.dtype is not t:
        raise TypeError("{} must be {}".format(name, t))


def check_contiguous(var, name):
    if not var.is_contiguous():
        raise ValueError("{} must be contiguous".format(name))


def check_dim(var, dim, name):
    if len(var.shape) != dim:
        raise ValueError("{} must be {}D".format(name, dim))


def check_equal(var1, name1, var2, name2):
    if var1 != var2:
        raise ValueError(
            "`{}` ({}) must equal to ".format(name1, var1)
            + "`{}` ({})".format(name2, var2)
        )


def check_device(var1, name1, var2, name2):
    if var1.device != var2.device:
        raise ValueError(
            "`{}` ({}) must be on the same ".format(name1, var1.device.type)
            + "device as `{}` ({})".format(name2, var2.device.type)
        )


def check_inputs(logits, targets, logit_lengths, target_lengths, blank):
    check_device(logits, "logits", targets, "targets")
    check_device(logits, "logits", targets, "logit_lengths")
    check_device(logits, "logits", targets, "target_lengths")

    check_type(logits, torch.float32, "logits")
    check_type(targets, torch.int32, "targets")
    check_type(logit_lengths, torch.int32, "logit_lengths")
    check_type(target_lengths, torch.int32, "target_lengths")

    check_contiguous(logits, "logits")
    check_contiguous(targets, "targets")
    check_contiguous(target_lengths, "target_lengths")
    check_contiguous(logit_lengths, "logit_lengths")

    check_dim(logits, 4, "logits")
    check_dim(targets, 2, "targets")
    check_dim(logit_lengths, 1, "logit_lengths")
    check_dim(target_lengths, 1, "target_lengths")

    check_equal(
        logit_lengths.shape[0], "logit_lengths.shape[0]",
        logits.shape[0], "logits.shape[0]")
    check_equal(
        target_lengths.shape[0], "target_lengths.shape[0]",
        logits.shape[0], "logits.shape[0]")
    check_equal(
        targets.shape[0], "targets.shape[0]",
        logits.shape[0], "logits.shape[0]")
    check_equal(
        targets.shape[1], "targets.shape[1]",
        torch.max(target_lengths), "torch.max(target_lengths)",
    )
    check_equal(
        logits.shape[1], "logits.shape[1]",
        torch.max(logit_lengths), "torch.max(logit_lengths)",
    )
    check_equal(
        logits.shape[2], "logits.shape[2]",
        torch.max(target_lengths) + 1, "torch.max(target_lengths) + 1",
    )

    if blank < 0 or blank >= logits.shape[-1]:
        raise ValueError(
            "blank ({}) must be within [0, logits.shape[-1]={})".format(
                blank, logits.shape[-1]
            )
        )
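For completeness, a hedged sketch of the module-level API after this change; the constructor arguments are the ones documented in the docstring above (runtime_check is no longer among them), while the tensor sizes are illustrative assumptions:

    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss

    loss_fn = RNNTLoss(blank=0, clamp=-1, fused_log_softmax=True, reuse_logits_for_grads=True)

    logits = torch.rand(2, 10, 6, 20, dtype=torch.float32)
    targets = torch.randint(1, 20, (2, 5), dtype=torch.int32)
    logit_lengths = torch.tensor([10, 8], dtype=torch.int32)
    target_lengths = torch.tensor([5, 3], dtype=torch.int32)

    # Input validation now happens inside the C++ operator rather than in Python helpers.
    costs = loss_fn(logits, targets, logit_lengths, target_lengths)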