Commit 32f661f0 (unverified)
Migrate RNNT loss input checks to C++ (#1494)
Authored May 06, 2021 by Caroline Chen; committed via GitHub on May 06, 2021
Parent: 723e9a52
Showing 3 changed files with 112 additions and 104 deletions.
torchaudio/csrc/rnnt/cpu/compute.cpp  (+56, -0)
torchaudio/csrc/rnnt/gpu/compute.cu   (+56, -0)
torchaudio/prototype/rnnt_loss.py     (+0, -104)
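Taken together, the three changes move the input validation that was previously opt-in on the Python side (gated behind runtime_check=False in torchaudio/prototype/rnnt_loss.py) into the C++ and CUDA bindings, where it now runs unconditionally. As an illustration only, here is a pure-Python mirror of the contract the new TORCH_CHECK calls enforce; the helper name validate_rnnt_inputs is hypothetical and not part of torchaudio:

    import torch

    def validate_rnnt_inputs(logits, targets, logit_lengths, target_lengths, blank):
        # Hypothetical mirror of the TORCH_CHECK conditions added in compute.cpp / compute.cu.
        assert logits.device == targets.device == logit_lengths.device == target_lengths.device
        assert logits.dtype in (torch.float32, torch.float16)
        assert targets.dtype == logit_lengths.dtype == target_lengths.dtype == torch.int32
        assert all(t.is_contiguous() for t in (logits, targets, logit_lengths, target_lengths))
        assert logits.dim() == 4 and targets.dim() == 2                 # (B, T, U + 1, C) and (B, U)
        assert logit_lengths.dim() == 1 and target_lengths.dim() == 1   # (B,) each
        batch = logits.size(0)
        assert logit_lengths.size(0) == target_lengths.size(0) == targets.size(0) == batch
        assert 0 <= blank < logits.size(-1)
        assert logits.size(1) == int(logit_lengths.max())               # T matches max source length
        assert logits.size(2) == int(target_lengths.max()) + 1          # U + 1
        assert targets.size(1) == int(target_lengths.max())             # U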
torchaudio/csrc/rnnt/cpu/compute.cpp
@@ -15,6 +15,62 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     double clamp,
     bool fused_log_smax = true,
     bool reuse_logits_for_grads = true) {
+  TORCH_CHECK(
+      logits.device().type() == targets.device().type(),
+      "logits and targets must be on the same device");
+  TORCH_CHECK(
+      logits.device().type() == src_lengths.device().type(),
+      "logits and logit_lengths must be on the same device");
+  TORCH_CHECK(
+      logits.device().type() == tgt_lengths.device().type(),
+      "logits and target_lengths must be on the same device");
+
+  TORCH_CHECK(
+      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
+      "logits must be float32 or float16 (half) type");
+  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
+  TORCH_CHECK(
+      src_lengths.dtype() == torch::kInt32,
+      "logit_lengths must be int32 type");
+  TORCH_CHECK(
+      tgt_lengths.dtype() == torch::kInt32,
+      "target_lengths must be int32 type");
+
+  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
+  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
+  TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
+  TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
+
+  TORCH_CHECK(
+      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
+  TORCH_CHECK(
+      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
+  TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
+  TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
+
+  TORCH_CHECK(
+      src_lengths.size(0) == logits.size(0),
+      "batch dimension mismatch between logits and logit_lengths");
+  TORCH_CHECK(
+      tgt_lengths.size(0) == logits.size(0),
+      "batch dimension mismatch between logits and target_lengths");
+  TORCH_CHECK(
+      targets.size(0) == logits.size(0),
+      "batch dimension mismatch between logits and targets");
+
+  TORCH_CHECK(
+      blank >= 0 && blank < logits.size(-1),
+      "blank must be within [0, logits.shape[-1])");
+
+  TORCH_CHECK(
+      logits.size(1) == at::max(src_lengths).item().toInt(),
+      "input length mismatch");
+  TORCH_CHECK(
+      logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
+      "output length mismatch");
+  TORCH_CHECK(
+      targets.size(1) == at::max(tgt_lengths).item().toInt(),
+      "target length mismatch");
+
   Options options;
   options.batchSize_ = src_lengths.size(0);
   options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
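The block above expresses in C++ what the deleted Python check_inputs helpers used to do, so an invalid input now fails inside the operator itself. A minimal sketch of how that surfaces from Python; the rnnt_loss function and the error message text are taken from this commit, while the import path, argument names, shapes, and values are assumptions:

    import torch
    from torchaudio.prototype.rnnt_loss import rnnt_loss  # import path assumed

    logits = torch.randn(2, 10, 6, 20)                         # (batch, time, target + 1, class)
    targets = torch.randint(1, 20, (2, 5), dtype=torch.int32)  # (batch, max target length)
    logit_lengths = torch.full((2,), 10, dtype=torch.int32)
    target_lengths = torch.full((2,), 5, dtype=torch.int32)

    # float64 logits violate the dtype check; TORCH_CHECK surfaces in Python as a
    # RuntimeError carrying "logits must be float32 or float16 (half) type".
    try:
        rnnt_loss(
            logits=logits.double(),
            targets=targets,
            logit_lengths=logit_lengths,
            target_lengths=target_lengths,
        )
    except RuntimeError as err:
        print(err)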
torchaudio/csrc/rnnt/gpu/compute.cu
@@ -16,6 +16,62 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     double clamp,
     bool fused_log_smax = true,
     bool reuse_logits_for_grads = true) {
+  TORCH_CHECK(
+      logits.device().type() == targets.device().type(),
+      "logits and targets must be on the same device");
+  TORCH_CHECK(
+      logits.device().type() == src_lengths.device().type(),
+      "logits and logit_lengths must be on the same device");
+  TORCH_CHECK(
+      logits.device().type() == tgt_lengths.device().type(),
+      "logits and target_lengths must be on the same device");
+
+  TORCH_CHECK(
+      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
+      "logits must be float32 or float16 (half) type");
+  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
+  TORCH_CHECK(
+      src_lengths.dtype() == torch::kInt32,
+      "logit_lengths must be int32 type");
+  TORCH_CHECK(
+      tgt_lengths.dtype() == torch::kInt32,
+      "target_lengths must be int32 type");
+
+  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
+  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
+  TORCH_CHECK(src_lengths.is_contiguous(), "logit_lengths must be contiguous");
+  TORCH_CHECK(tgt_lengths.is_contiguous(), "target_lengths must be contiguous");
+
+  TORCH_CHECK(
+      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
+  TORCH_CHECK(
+      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
+  TORCH_CHECK(src_lengths.dim() == 1, "logit_lengths must be 1-D");
+  TORCH_CHECK(tgt_lengths.dim() == 1, "target_lengths must be 1-D");
+
+  TORCH_CHECK(
+      src_lengths.size(0) == logits.size(0),
+      "batch dimension mismatch between logits and logit_lengths");
+  TORCH_CHECK(
+      tgt_lengths.size(0) == logits.size(0),
+      "batch dimension mismatch between logits and target_lengths");
+  TORCH_CHECK(
+      targets.size(0) == logits.size(0),
+      "batch dimension mismatch between logits and targets");
+
+  TORCH_CHECK(
+      blank >= 0 && blank < logits.size(-1),
+      "blank must be within [0, logits.shape[-1])");
+
+  TORCH_CHECK(
+      logits.size(1) == at::max(src_lengths).item().toInt(),
+      "input length mismatch");
+  TORCH_CHECK(
+      logits.size(2) == at::max(tgt_lengths).item().toInt() + 1,
+      "output length mismatch");
+  TORCH_CHECK(
+      targets.size(1) == at::max(tgt_lengths).item().toInt(),
+      "target length mismatch");
+
   Options options;
   options.batchSize_ = src_lengths.size(0);
   options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
torchaudio/prototype/rnnt_loss.py
@@ -80,7 +80,6 @@ class _RNNT(torch.autograd.Function):
         target_lengths,
         blank=-1,
         clamp=-1,
-        runtime_check=False,
         fused_log_softmax=True,
         reuse_logits_for_grads=True,
     ):
@@ -101,15 +100,6 @@ class _RNNT(torch.autograd.Function):
         if blank < 0:  # reinterpret blank index if blank < 0.
             blank = logits.shape[-1] + blank
 
-        if runtime_check:
-            check_inputs(
-                logits=logits,
-                targets=targets,
-                logit_lengths=logit_lengths,
-                target_lengths=target_lengths,
-                blank=blank,
-            )
-
         costs, gradients = torch.ops.torchaudio.rnnt_loss(
             logits=logits,
             targets=targets,
@@ -137,7 +127,6 @@ class _RNNT(torch.autograd.Function):
             None,  # target_lengths
             None,  # blank
             None,  # clamp
-            None,  # runtime_check
             None,  # fused_log_softmax
             None,  # reuse_logits_for_grads
         )
@@ -150,7 +139,6 @@ def rnnt_loss(
     target_lengths,
     blank=-1,
    clamp=-1,
-    runtime_check=False,
    fused_log_softmax=True,
    reuse_logits_for_grads=True,
 ):
@@ -185,7 +173,6 @@ def rnnt_loss(
         target_lengths,
         blank,
         clamp,
-        runtime_check,
         fused_log_softmax,
         reuse_logits_for_grads,
     )
@@ -203,7 +190,6 @@ class RNNTLoss(torch.nn.Module):
     Args:
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
     """
@@ -212,14 +198,12 @@ class RNNTLoss(torch.nn.Module):
         self,
         blank=-1,
         clamp=-1,
-        runtime_check=False,
         fused_log_softmax=True,
         reuse_logits_for_grads=True,
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
-        self.runtime_check = runtime_check
         self.fused_log_softmax = fused_log_softmax
         self.reuse_logits_for_grads = reuse_logits_for_grads
 
@@ -244,94 +228,6 @@ class RNNTLoss(torch.nn.Module):
             target_lengths,
             self.blank,
             self.clamp,
-            self.runtime_check,
             self.fused_log_softmax,
             self.reuse_logits_for_grads,
         )
-
-
-def check_type(var, t, name):
-    if var.dtype is not t:
-        raise TypeError("{} must be {}".format(name, t))
-
-
-def check_contiguous(var, name):
-    if not var.is_contiguous():
-        raise ValueError("{} must be contiguous".format(name))
-
-
-def check_dim(var, dim, name):
-    if len(var.shape) != dim:
-        raise ValueError("{} must be {}D".format(name, dim))
-
-
-def check_equal(var1, name1, var2, name2):
-    if var1 != var2:
-        raise ValueError(
-            "`{}` ({}) must equal to ".format(name1, var1)
-            + "`{}` ({})".format(name2, var2)
-        )
-
-
-def check_device(var1, name1, var2, name2):
-    if var1.device != var2.device:
-        raise ValueError(
-            "`{}` ({}) must be on the same ".format(name1, var1.device.type)
-            + "device as `{}` ({})".format(name2, var2.device.type)
-        )
-
-
-def check_inputs(logits, targets, logit_lengths, target_lengths, blank):
-    check_device(logits, "logits", targets, "targets")
-    check_device(logits, "logits", targets, "logit_lengths")
-    check_device(logits, "logits", targets, "target_lengths")
-
-    check_type(logits, torch.float32, "logits")
-    check_type(targets, torch.int32, "targets")
-    check_type(logit_lengths, torch.int32, "logit_lengths")
-    check_type(target_lengths, torch.int32, "target_lengths")
-
-    check_contiguous(logits, "logits")
-    check_contiguous(targets, "targets")
-    check_contiguous(target_lengths, "target_lengths")
-    check_contiguous(logit_lengths, "logit_lengths")
-
-    check_dim(logits, 4, "logits")
-    check_dim(targets, 2, "targets")
-    check_dim(logit_lengths, 1, "logit_lengths")
-    check_dim(target_lengths, 1, "target_lengths")
-
-    check_equal(
-        logit_lengths.shape[0], "logit_lengths.shape[0]", logits.shape[0], "logits.shape[0]"
-    )
-    check_equal(
-        target_lengths.shape[0], "target_lengths.shape[0]", logits.shape[0], "logits.shape[0]"
-    )
-    check_equal(
-        targets.shape[0], "targets.shape[0]", logits.shape[0], "logits.shape[0]"
-    )
-    check_equal(
-        targets.shape[1],
-        "targets.shape[1]",
-        torch.max(target_lengths),
-        "torch.max(target_lengths)",
-    )
-    check_equal(
-        logits.shape[1],
-        "logits.shape[1]",
-        torch.max(logit_lengths),
-        "torch.max(logit_lengths)",
-    )
-    check_equal(
-        logits.shape[2],
-        "logits.shape[2]",
-        torch.max(target_lengths) + 1,
-        "torch.max(target_lengths) + 1",
-    )
-
-    if blank < 0 or blank >= logits.shape[-1]:
-        raise ValueError(
-            "blank ({}) must be within [0, logits.shape[-1]={})".format(
-                blank, logits.shape[-1]
-            )
-        )
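With runtime_check removed from this module, here is a hedged usage sketch of the module-level API after this commit. The RNNTLoss class and its remaining constructor arguments come from the diff; the import path, the shapes, and the forward argument order are assumptions based on the surrounding code:

    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss  # import path assumed

    batch, max_src, max_tgt, num_classes = 2, 10, 5, 20
    logits = torch.randn(batch, max_src, max_tgt + 1, num_classes)
    targets = torch.randint(1, num_classes, (batch, max_tgt), dtype=torch.int32)
    logit_lengths = torch.full((batch,), max_src, dtype=torch.int32)
    target_lengths = torch.full((batch,), max_tgt, dtype=torch.int32)

    # Passing runtime_check=... would now raise a TypeError: the argument no longer
    # exists, and the equivalent validation always runs inside the C++/CUDA operator.
    loss_fn = RNNTLoss(blank=0)
    loss = loss_fn(logits, targets, logit_lengths, target_lengths)
    print(loss)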