OpenDAS / Torchaudio, commit 16f3b2f9 (unverified)
Authored Aug 03, 2021 by Caroline Chen; committed via GitHub on Aug 03, 2021

Remove reuse_logits_for_grads option for RNNTLoss (#1610)

parent 25ceee71
Showing 9 changed files with 20 additions and 53 deletions.
    test/torchaudio_unittest/rnnt/autograd_impl.py     +1  -2
    test/torchaudio_unittest/rnnt/rnnt_loss_impl.py     +4  -8
    test/torchaudio_unittest/rnnt/utils.py              +1  -2
    torchaudio/csrc/rnnt/autograd.cpp                   +4  -8
    torchaudio/csrc/rnnt/compute.cpp                    +3  -6
    torchaudio/csrc/rnnt/compute.h                      +1  -2
    torchaudio/csrc/rnnt/cpu/compute.cpp                +2  -7
    torchaudio/csrc/rnnt/gpu/compute.cu                 +2  -7
    torchaudio/prototype/rnnt_loss.py                   +2  -11
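In user-facing terms, the commit removes the reuse_logits_for_grads keyword from both the functional rnnt_loss and the RNNTLoss module in torchaudio.prototype.rnnt_loss; the gradient buffer is now always allocated separately inside the kernel (see the cpu/compute.cpp and gpu/compute.cu hunks below). A minimal sketch of the call-site impact, assuming a torchaudio build with the prototype RNNT loss enabled; the tensor shapes and blank index are illustrative and not taken from the diff:

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

# Illustrative shapes: batch=2, time=10, max_target_length+1=5, num_classes=20.
logits = torch.randn(2, 10, 5, 20, requires_grad=True)
targets = torch.randint(1, 20, (2, 4), dtype=torch.int32)
logit_lengths = torch.full((2,), 10, dtype=torch.int32)
target_lengths = torch.full((2,), 4, dtype=torch.int32)

# Before this commit the option could be passed explicitly:
#     loss_fn = RNNTLoss(blank=0, reuse_logits_for_grads=False)
# After it, the keyword no longer exists (passing it raises TypeError) and
# gradients are always written to a freshly allocated buffer internally.
loss_fn = RNNTLoss(blank=0)
loss = loss_fn(logits, targets, logit_lengths, target_lengths)
loss.backward()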
test/torchaudio_unittest/rnnt/autograd_impl.py  (view file @ 16f3b2f9)

@@ -53,7 +53,7 @@ class Autograd(TestBaseMixin):
             data["logit_lengths"],
             data["target_lengths"],
         )
-        loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False)
+        loss = RNNTLoss(blank=data["blank"])
         self.assert_grad(loss, inputs, enable_all_grad=False)
@@ -72,7 +72,6 @@ class Autograd(TestBaseMixin):
             data["blank"],  # blank
             -1,  # clamp
             True,  # fused_log_softmax
-            False,  # reuse_logits_for_grads
         )
         self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py  (view file @ 16f3b2f9)

@@ -17,14 +17,10 @@ class RNNTLossTest:
         self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
     ):
         logits_shape = data["logits"].shape
-        for reuse_logits_for_grads in [False, True]:
-            with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads):
-                costs, gradients = compute_with_pytorch_transducer(
-                    data=data, reuse_logits_for_grads=reuse_logits_for_grads
-                )
-                self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
-                self.assertEqual(logits_shape, gradients.shape)
-                self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
+        costs, gradients = compute_with_pytorch_transducer(data=data)
+        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
+        self.assertEqual(logits_shape, gradients.shape)
+        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)

     def test_basic_backward(self):
         rnnt_loss = RNNTLoss()
test/torchaudio_unittest/rnnt/utils.py  (view file @ 16f3b2f9)

@@ -23,11 +23,10 @@ def compute_with_numpy_transducer(data):
     return costs, gradients


-def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
+def compute_with_pytorch_transducer(data):
     costs = RNNTLoss(
         blank=data["blank"],
         fused_log_softmax=data.get("fused_log_softmax", True),
-        reuse_logits_for_grads=reuse_logits_for_grads,
         reduction="none",
     )(
         logits=data["logits"],
torchaudio/csrc/rnnt/autograd.cpp  (view file @ 16f3b2f9)

@@ -14,8 +14,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       const torch::Tensor& target_lengths,
       int64_t blank,
       double clamp,
-      bool fused_log_softmax = true,
-      bool reuse_logits_for_grads = true) {
+      bool fused_log_softmax = true) {
     torch::Tensor undef;
     auto result = rnnt_loss(
         logits,
@@ -24,8 +23,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
         target_lengths,
         blank,
         clamp,
-        fused_log_softmax,
-        reuse_logits_for_grads);
+        fused_log_softmax);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
     ctx->save_for_backward({grads});
@@ -51,8 +49,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
       logits,
@@ -61,8 +58,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
       target_lengths,
       blank,
       clamp,
-      fused_log_softmax,
-      reuse_logits_for_grads);
+      fused_log_softmax);
   return std::make_tuple(results[0], results[1]);
 }
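For readers less familiar with the C++ autograd extension API: RNNTLossFunction above follows the same pattern as a Python torch.autograd.Function, in which the forward pass computes costs and gradients in one shot, stashes the gradients with save_for_backward, and the backward pass only rescales them by the incoming grad_output. A rough Python analogue of that structure, where compute_costs_and_grads is a hypothetical stand-in for the dispatched torchaudio::rnnt_loss kernel:

import torch

class RNNTLossLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, targets, logit_lengths, target_lengths):
        # Hypothetical helper standing in for the dispatched C++ kernel; it
        # returns per-sequence costs plus the full (B, T, U, D) gradient tensor.
        costs, grads = compute_costs_and_grads(
            logits, targets, logit_lengths, target_lengths
        )
        ctx.save_for_backward(grads)
        return costs

    @staticmethod
    def backward(ctx, grad_output):
        (grads,) = ctx.saved_tensors
        # Scale the precomputed gradients by the per-sequence grad_output and
        # return None for the non-differentiable integer inputs.
        grad_logits = grads * grad_output.view(-1, 1, 1, 1)
        return grad_logits, None, None, None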
torchaudio/csrc/rnnt/compute.cpp  (view file @ 16f3b2f9)

@@ -8,8 +8,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                        .typed<decltype(rnnt_loss)>();
@@ -20,8 +19,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
       target_lengths,
       blank,
       clamp,
-      fused_log_softmax,
-      reuse_logits_for_grads);
+      fused_log_softmax);
 }

 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -32,6 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "Tensor target_lengths,"
       "int blank,"
       "float clamp,"
-      "bool fused_log_softmax=True,"
-      "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
+      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
 }
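The schema registered through TORCH_LIBRARY_FRAGMENT is what the Python layer binds against, so dropping the argument here is what forces the matching change in torchaudio/prototype/rnnt_loss.py below. A hedged sketch of calling the updated operator directly, assuming the extension is built and registered; shapes and the blank index are illustrative, and only the argument names visible in this hunk are taken from the diff:

import torch
import torchaudio  # importing torchaudio loads the extension that registers torchaudio::rnnt_loss

logits = torch.randn(2, 10, 5, 20, requires_grad=True)
targets = torch.randint(1, 20, (2, 4), dtype=torch.int32)
logit_lengths = torch.full((2,), 10, dtype=torch.int32)
target_lengths = torch.full((2,), 4, dtype=torch.int32)

# Updated tail of the schema: "... float clamp, bool fused_log_softmax=True) -> (Tensor, Tensor?)"
costs, grads = torch.ops.torchaudio.rnnt_loss(
    logits, targets, logit_lengths, target_lengths,
    blank=0, clamp=-1.0, fused_log_softmax=True,
)
# The optional second output carries the precomputed gradients, or None when
# logits does not require grad.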
torchaudio/csrc/rnnt/compute.h  (view file @ 16f3b2f9)

@@ -9,5 +9,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax,
-    bool reuse_logits_for_grads);
+    bool fused_log_softmax);
torchaudio/csrc/rnnt/cpu/compute.cpp  (view file @ 16f3b2f9)

@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -92,11 +91,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

   c10::optional<torch::Tensor> gradients = c10::nullopt;
   if (logits.requires_grad()) {
-    if (reuse_logits_for_grads) {
-      gradients = logits;
-    } else {
-      gradients = torch::zeros_like(logits);
-    }
+    gradients = torch::zeros_like(logits);
   }

   torch::Tensor int_workspace = torch::empty(
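This hunk (and its CUDA counterpart below) is the behavioral core of the change: the gradient buffer is now always a fresh zeros_like allocation rather than an alias of logits, so the tensor handed to the kernel is no longer overwritten while the loss is computed. A small hedged check of that property from the Python side, with illustrative shapes and assuming the prototype build:

import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss

logits = torch.randn(2, 10, 5, 20, requires_grad=True)
targets = torch.randint(0, 19, (2, 4), dtype=torch.int32)
logit_lengths = torch.full((2,), 10, dtype=torch.int32)
target_lengths = torch.full((2,), 4, dtype=torch.int32)

snapshot = logits.detach().clone()
loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)  # default blank=-1 maps to the last class
loss.backward()

# Gradients live in a separate buffer, so the forward inputs stay untouched.
assert torch.equal(logits.detach(), snapshot)
assert logits.grad is not None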
torchaudio/csrc/rnnt/gpu/compute.cu  (view file @ 16f3b2f9)

@@ -14,8 +14,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -95,11 +94,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

   c10::optional<torch::Tensor> gradients = c10::nullopt;
   if (logits.requires_grad()) {
-    if (reuse_logits_for_grads) {
-      gradients = logits;
-    } else {
-      gradients = torch::zeros_like(logits);
-    }
+    gradients = torch::zeros_like(logits);
   }

   torch::Tensor int_workspace = torch::empty(
torchaudio/prototype/rnnt_loss.py  (view file @ 16f3b2f9)

@@ -15,7 +15,6 @@ def rnnt_loss(
     blank: int = -1,
     clamp: float = -1,
     fused_log_softmax: bool = True,
-    reuse_logits_for_grads: bool = True,
     reduction: str = "mean",
 ):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -33,7 +32,6 @@ def rnnt_loss(
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
@@ -46,9 +44,6 @@ def rnnt_loss(
     if not fused_log_softmax:
         logits = torch.nn.functional.log_softmax(logits, dim=-1)
-        reuse_logits_for_grads = (
-            False  # softmax needs the original logits value
-        )

     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
@@ -60,8 +55,8 @@ def rnnt_loss(
         target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_softmax=fused_log_softmax,
-        reuse_logits_for_grads=reuse_logits_for_grads,
+        fused_log_softmax=fused_log_softmax,
     )

     if reduction == 'mean':
         return costs.mean()
@@ -83,7 +78,6 @@ class RNNTLoss(torch.nn.Module):
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
     """
@@ -93,14 +87,12 @@ class RNNTLoss(torch.nn.Module):
         blank: int = -1,
         clamp: float = -1.,
         fused_log_softmax: bool = True,
-        reuse_logits_for_grads: bool = True,
         reduction: str = "mean",
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
         self.fused_log_softmax = fused_log_softmax
-        self.reuse_logits_for_grads = reuse_logits_for_grads
         self.reduction = reduction

     def forward(
@@ -129,6 +121,5 @@ class RNNTLoss(torch.nn.Module):
             self.blank,
             self.clamp,
             self.fused_log_softmax,
-            self.reuse_logits_for_grads,
             self.reduction
         )
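One detail specific to this file: the fused_log_softmax=False path previously had to force reuse_logits_for_grads off because the softmax needed the original logits values; with the option gone, that special case disappears entirely. A hedged usage sketch of that path against the post-commit signature, with illustrative shapes:

import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss

logits = torch.randn(2, 10, 5, 20, requires_grad=True)
targets = torch.randint(0, 19, (2, 4), dtype=torch.int32)
logit_lengths = torch.full((2,), 10, dtype=torch.int32)
target_lengths = torch.full((2,), 4, dtype=torch.int32)

# Per the @@ -46,9 +44,6 @@ hunk above, when fused_log_softmax=False the Python
# wrapper applies log_softmax itself before handing the tensor to the kernel,
# and no longer needs to toggle any gradient-reuse behavior while doing so.
loss = rnnt_loss(
    logits, targets, logit_lengths, target_lengths,
    fused_log_softmax=False, reduction="mean",
)
loss.backward()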