OpenDAS / Torchaudio / Commits

Commit af7eb4d6 (unverified)
Authored May 19, 2021 by Caroline Chen; committed by GitHub, May 19, 2021

Add torchscript support to RNNT Loss (#1507)
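With this change, the prototype RNNT loss can be compiled and serialized with TorchScript. As a minimal sketch of the workflow the commit enables (assuming a torchaudio build with the transducer enabled; the artifact name is illustrative):

    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss

    loss = RNNTLoss()
    scripted = torch.jit.script(loss)  # compiles the now-annotated wrapper
    scripted.save("rnnt_loss.pt")      # serializable, as the new tests verify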
Parent: 079b3f5d

Showing 9 changed files with 234 additions and 107 deletions.
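The change has three parts: new TorchScript consistency tests for the loss (the first three files, plus a small device fix in the test utilities), a C++ autograd binding registered with the dispatcher (autograd.cpp, compute.cpp, compute.h, and the CMake entry), and the removal of the Python torch.autograd.Function in favor of type-annotated wrappers that TorchScript can compile (rnnt_loss.py). The files: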
  test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py    +10   -0
  test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py   +11   -0
  test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py        +70   -0
  test/torchaudio_unittest/rnnt/utils.py                                +4   -4
  torchaudio/csrc/CMakeLists.txt                                        +1   -0
  torchaudio/csrc/rnnt/autograd.cpp                                     +74   -0
  torchaudio/csrc/rnnt/compute.cpp                                      +24   -0
  torchaudio/csrc/rnnt/compute.h                                        +13   -0
  torchaudio/prototype/rnnt_loss.py                                     +27   -103
test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py (new file, mode 100644)

import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cpu')
test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py (new file, mode 100644)

import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
@skipIfNoCuda
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cuda')
test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py (new file, mode 100644)

import torch
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss


class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
    """Implements tests for the RNNT loss that are run on different devices"""
    def _assert_consistency(self, func, tensor, shape_only=False):
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor)

        self.assertEqual(ts_output, output)

    def test_rnnt_loss(self):
        def func(
            logits,
        ):
            targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
            logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            return rnnt_loss(logits, targets, logit_lengths, target_lengths)

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])

        self._assert_consistency(func, logits)

    def test_RNNTLoss(self):
        func = RNNTLoss()

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
        logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
        target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)

        tensor = logits.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor, targets, logit_lengths, target_lengths)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)

        self.assertEqual(ts_output, output)
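Note that _assert_consistency exercises the full serialization round trip, saving the scripted function to disk and loading it back instead of calling the in-memory script module, so it also guards against save/load regressions. The RNG is re-seeded before each run so that any randomness in the function under test is identical on the eager and TorchScript paths.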
test/torchaudio_unittest/rnnt/utils.py

...

@@ -405,10 +405,10 @@ def get_numpy_random_data(
 def numpy_to_torch(data, device, requires_grad=True):
-    logits = torch.from_numpy(data["logits"])
-    targets = torch.from_numpy(data["targets"])
-    logit_lengths = torch.from_numpy(data["logit_lengths"])
-    target_lengths = torch.from_numpy(data["target_lengths"])
+    logits = torch.from_numpy(data["logits"]).to(device=device)
+    targets = torch.from_numpy(data["targets"]).to(device=device)
+    logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
+    target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)
     if "nbest_wers" in data:
         data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)

...
torchaudio/csrc/CMakeLists.txt

...

@@ -19,6 +19,7 @@ if(BUILD_TRANSDUCER)
     rnnt/compute_alphas.cpp
     rnnt/compute_betas.cpp
     rnnt/compute.cpp
+    rnnt/autograd.cpp
   )
   if (USE_CUDA)

...
torchaudio/csrc/rnnt/autograd.cpp (new file, mode 100644)

#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& src_lengths,
      const torch::Tensor& tgt_lengths,
      int64_t blank,
      double clamp,
      bool fused_log_smax = true,
      bool reuse_logits_for_grads = true) {
    at::AutoNonVariableTypeMode g;
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
        targets,
        src_lengths,
        tgt_lengths,
        blank,
        clamp,
        fused_log_smax,
        reuse_logits_for_grads);
    auto costs = std::get<0>(result);
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    torch::Tensor undef;
    return {result, undef, undef, undef, undef, undef, undef, undef};
  }
};

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  auto results = RNNTLossFunction::apply(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
  return std::make_tuple(results[0], results[1]);
}

TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", rnnt_loss_autograd);
}

} // namespace rnnt
} // namespace torchaudio
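The new tests compare only forward outputs, so a complementary sanity check is that gradients flow through this Autograd kernel from Python. A hypothetical sketch, not part of the commit, assuming a transducer-enabled build and the logits layout used in the tests (batch, time, target length + 1, classes):

    import torch
    from torchaudio.prototype.rnnt_loss import rnnt_loss

    logits = torch.rand(1, 2, 3, 5, requires_grad=True)
    targets = torch.tensor([[1, 2]], dtype=torch.int32)
    logit_lengths = torch.tensor([2], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    costs = rnnt_loss(logits, targets, logit_lengths, target_lengths)
    costs.sum().backward()  # routed through RNNTLossFunction::backward
    assert logits.grad.shape == logits.shape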
torchaudio/csrc/rnnt/compute.cpp

#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                       .typed<decltype(rnnt_loss)>();
  return op.call(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(

...
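This wrapper is the standard re-dispatch idiom: findSchemaOrThrow looks up the torchaudio::rnnt_loss schema defined in the fragment below, and op.call goes back through the dispatcher, so the same typed entry point reaches the CPU or CUDA kernel according to the inputs' dispatch keys and is intercepted by the Autograd kernel from autograd.cpp when gradients are being tracked. It also gives autograd.cpp a plain C++ function to call via the shared header.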
torchaudio/csrc/rnnt/compute.h (new file, mode 100644)

#pragma once

#include <torch/script.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax,
    bool reuse_logits_for_grads);
torchaudio/prototype/rnnt_loss.py

 import torch
+from torch import Tensor

 __all__ = [
     "RNNTLoss",

...

@@ -19,15 +20,6 @@ def _rnnt_loss_alphas(
     See documentation for RNNTLoss
     """
-    targets = targets.to(device=logits.device)
-    logit_lengths = logit_lengths.to(device=logits.device)
-    target_lengths = target_lengths.to(device=logits.device)
-
-    # make sure all int tensors are of type int32.
-    targets = targets.int()
-    logit_lengths = logit_lengths.int()
-    target_lengths = target_lengths.int()
-
     return torch.ops.torchaudio.rnnt_loss_alphas(
         logits,
         targets,

...

@@ -51,15 +43,6 @@ def _rnnt_loss_betas(
     See documentation for RNNTLoss
     """
-    targets = targets.to(device=logits.device)
-    logit_lengths = logit_lengths.to(device=logits.device)
-    target_lengths = target_lengths.to(device=logits.device)
-
-    # make sure all int tensors are of type int32.
-    targets = targets.int()
-    logit_lengths = logit_lengths.int()
-    target_lengths = target_lengths.int()
-
     return torch.ops.torchaudio.rnnt_loss_betas(
         logits,
         targets,

...

@@ -70,77 +53,15 @@ def _rnnt_loss_betas(
     )


-class _RNNT(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank=-1,
-        clamp=-1,
-        fused_log_softmax=True,
-        reuse_logits_for_grads=True,
-    ):
-        """
-        See documentation for RNNTLoss
-        """
-        # move everything to the same device.
-        targets = targets.to(device=logits.device)
-        logit_lengths = logit_lengths.to(device=logits.device)
-        target_lengths = target_lengths.to(device=logits.device)
-
-        # make sure all int tensors are of type int32.
-        targets = targets.int()
-        logit_lengths = logit_lengths.int()
-        target_lengths = target_lengths.int()
-
-        if blank < 0:  # reinterpret blank index if blank < 0.
-            blank = logits.shape[-1] + blank
-
-        costs, gradients = torch.ops.torchaudio.rnnt_loss(
-            logits=logits,
-            targets=targets,
-            src_lengths=logit_lengths,
-            tgt_lengths=target_lengths,
-            blank=blank,
-            clamp=clamp,
-            fused_log_smax=fused_log_softmax,
-            reuse_logits_for_grads=reuse_logits_for_grads,
-        )
-
-        ctx.grads = gradients
-
-        return costs
-
-    @staticmethod
-    def backward(ctx, output_gradients):
-        output_gradients = output_gradients.view(-1, 1, 1, 1).to(ctx.grads)
-        ctx.grads.mul_(output_gradients).to(ctx.grads)
-
-        return (
-            ctx.grads,  # logits
-            None,  # targets
-            None,  # logit_lengths
-            None,  # target_lengths
-            None,  # blank
-            None,  # clamp
-            None,  # fused_log_softmax
-            None,  # reuse_logits_for_grads
-        )
-
-
 def rnnt_loss(
-    logits,
-    targets,
-    logit_lengths,
-    target_lengths,
-    blank=-1,
-    clamp=-1,
-    fused_log_softmax=True,
-    reuse_logits_for_grads=True,
+    logits: Tensor,
+    targets: Tensor,
+    logit_lengths: Tensor,
+    target_lengths: Tensor,
+    blank: int = -1,
+    clamp: float = -1,
+    fused_log_softmax: bool = True,
+    reuse_logits_for_grads: bool = True,
 ):
     """
     Compute the RNN Transducer Loss.

...

@@ -166,17 +87,20 @@ def rnnt_loss(
             False  # softmax needs the original logits value
         )

-    cost = _RNNT.apply(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax,
-        reuse_logits_for_grads,
-    )
-    return cost
+    if blank < 0:  # reinterpret blank index if blank < 0.
+        blank = logits.shape[-1] + blank
+
+    costs, gradients = torch.ops.torchaudio.rnnt_loss(
+        logits=logits,
+        targets=targets,
+        src_lengths=logit_lengths,
+        tgt_lengths=target_lengths,
+        blank=blank,
+        clamp=clamp,
+        fused_log_smax=fused_log_softmax,
+        reuse_logits_for_grads=reuse_logits_for_grads,)
+    return costs


 class RNNTLoss(torch.nn.Module):

...

@@ -196,10 +120,10 @@ class RNNTLoss(torch.nn.Module):
     def __init__(
         self,
-        blank=-1,
-        clamp=-1,
-        fused_log_softmax=True,
-        reuse_logits_for_grads=True,
+        blank: int = -1,
+        clamp: float = -1.,
+        fused_log_softmax: bool = True,
+        reuse_logits_for_grads: bool = True,
     ):
         super().__init__()
         self.blank = blank

...
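In the new rnnt_loss body, the negative-blank handling that previously lived in _RNNT.forward is inlined so TorchScript can compile it directly. The convention mirrors Python negative indexing; a small worked example with the 5-class logits from the tests:

    blank = -1
    num_classes = 5        # logits.shape[-1]
    if blank < 0:          # reinterpret blank index if blank < 0
        blank = num_classes + blank
    assert blank == 4      # the last class index serves as the blank symbol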