Torchaudio · Commit bac32ec1 (unverified)
Authored Jun 24, 2021 by Caroline Chen; committed via GitHub on Jun 24, 2021

Add reduction parameter for RNNT loss (#1590)

Parent commit: 2376e9c9
Showing 2 changed files with 17 additions and 1 deletion:

test/torchaudio_unittest/rnnt/utils.py (+1, -0)
torchaudio/prototype/rnnt_loss.py (+16, -1)
test/torchaudio_unittest/rnnt/utils.py

@@ -31,6 +31,7 @@ def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
        blank=data["blank"],
        fused_log_softmax=data.get("fused_log_softmax", True),
        reuse_logits_for_grads=reuse_logits_for_grads,
        reduction="none",
    )(
        logits=data["logits"],
        logit_lengths=data["logit_lengths"],
...
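The test helper passes reduction="none" so the loss returns one cost per sequence, which is what an elementwise comparison against the reference transducer needs. A minimal sketch of that comparison pattern, with made-up values and plain tensors standing in for the two implementations (nothing below comes from the test suite itself):

import torch

# Hypothetical per-sequence costs (values invented for illustration).
costs_pytorch = torch.tensor([3.14, 2.72])    # what reduction="none" would return, shape (batch,)
costs_reference = torch.tensor([3.14, 2.72])  # costs from a reference implementation, same shape

# An elementwise check is only possible on unreduced costs;
# with reduction="mean" or "sum" only a single scalar would remain.
assert costs_pytorch.shape == (2,)
assert torch.allclose(costs_pytorch, costs_reference)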
torchaudio/prototype/rnnt_loss.py

@@ -16,6 +16,7 @@ def rnnt_loss(
    clamp: float = -1,
    fused_log_softmax: bool = True,
    reuse_logits_for_grads: bool = True,
    reduction: str = "mean",
):
    """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
    [:footcite:`graves2012sequence`].
...
@@ -31,14 +32,18 @@ def rnnt_loss(
        target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
        blank (int, opt): blank label (Default: ``-1``)
        clamp (float): clamp for gradients (Default: ``-1``)
        runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)

    Returns:
        Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
        otherwise scalar.
    """
    if reduction not in ['none', 'mean', 'sum']:
        raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")

    if not fused_log_softmax:
        logits = torch.nn.functional.log_softmax(logits, dim=-1)
        reuse_logits_for_grads = (
...
@@ -58,6 +63,11 @@ def rnnt_loss(
        fused_log_softmax=fused_log_softmax,
        reuse_logits_for_grads=reuse_logits_for_grads,
    )
    if reduction == 'mean':
        return costs.mean()
    elif reduction == 'sum':
        return costs.sum()
    return costs
...
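For reference, the new reduction handling is just a post-processing step on the per-sequence cost tensor. A standalone sketch of the same logic on a dummy costs tensor (the helper name and the values are invented; only the none/mean/sum branching mirrors the change above):

import torch

def apply_reduction(costs: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    # Mirrors the branching added to rnnt_loss: validate the option, then reduce.
    if reduction not in ['none', 'mean', 'sum']:
        raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
    if reduction == 'mean':
        return costs.mean()
    elif reduction == 'sum':
        return costs.sum()
    return costs

costs = torch.tensor([4.0, 6.0])        # made-up per-sequence costs, shape (batch,)
print(apply_reduction(costs, "none"))   # tensor([4., 6.])
print(apply_reduction(costs, "mean"))   # tensor(5.)
print(apply_reduction(costs, "sum"))    # tensor(10.)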
@@ -74,6 +84,8 @@ class RNNTLoss(torch.nn.Module):
        clamp (float): clamp for gradients (Default: ``-1``)
        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
    """

    def __init__(
...
@@ -82,12 +94,14 @@ class RNNTLoss(torch.nn.Module):
        clamp: float = -1.,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = True,
        reduction: str = "mean",
    ):
        super().__init__()
        self.blank = blank
        self.clamp = clamp
        self.fused_log_softmax = fused_log_softmax
        self.reuse_logits_for_grads = reuse_logits_for_grads
        self.reduction = reduction

    def forward(
        self,
...
@@ -116,4 +130,5 @@ class RNNTLoss(torch.nn.Module):
            self.clamp,
            self.fused_log_softmax,
            self.reuse_logits_for_grads,
            self.reduction)
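Putting the change together, the module form accepts the same reduction argument as the functional form. The snippet below is a usage sketch, not code from the commit: it assumes torchaudio was built with the RNN-T prototype available at torchaudio.prototype.rnnt_loss, and it assumes the usual RNN-T conventions (logits of shape (batch, max time, max target length + 1, num classes), int32 targets and length tensors, blank as the last class index); the keyword names follow the test snippet above and standard torchaudio naming, so treat them as assumptions.

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss  # prototype location at this commit

batch, max_time, max_target, num_classes = 2, 10, 5, 20

# Assumed shapes/dtypes; they are not spelled out in this diff.
logits = torch.randn(batch, max_time, max_target + 1, num_classes, requires_grad=True)
targets = torch.randint(0, num_classes - 1, (batch, max_target), dtype=torch.int32)
logit_lengths = torch.full((batch,), max_time, dtype=torch.int32)
target_lengths = torch.full((batch,), max_target, dtype=torch.int32)

# New in this commit: reduction="none" | "mean" | "sum" (default "mean").
# "mean"/"sum" give a scalar that can go straight to .backward();
# "none" keeps one cost per sequence, as the updated test relies on.
loss_fn = RNNTLoss(reduction="mean")
loss = loss_fn(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
)
loss.backward()
print(loss.item())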