OpenDAS / Torchaudio / Commits / 2c115821
"stubs/vscode:/vscode.git/clone" did not exist on "bde4bac5fe3f6b040ac6d75e3bd631be7f504c27"
Unverified
Commit 2c115821, authored Aug 19, 2021 by Caroline Chen, committed by GitHub on Aug 19, 2021
Parent: b7d44d97

Move RNNT Loss out of prototype (#1711)
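In effect, this commit promotes the RNNT loss from the prototype namespace to the stable API: torchaudio/prototype/rnnt_loss.py is deleted and the same functionality is exposed as torchaudio.functional.rnnt_loss and torchaudio.transforms.RNNTLoss. A minimal sketch of the import change, assuming a torchaudio build with the RNNT extension compiled in (without it, the related tests below are skipped via skipIfNoRNNT):

# Before this commit (prototype namespace, removed here):
# from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss

# After this commit (stable API):
from torchaudio.functional import rnnt_loss   # functional form
from torchaudio.transforms import RNNTLoss    # torch.nn.Module form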
Showing 20 changed files with 183 additions and 605 deletions (+183, -605)
test/torchaudio_unittest/rnnt/autograd_cpu_test.py                        +0   -10
test/torchaudio_unittest/rnnt/autograd_cuda_test.py                       +0   -11
test/torchaudio_unittest/rnnt/autograd_impl.py                            +0   -93
test/torchaudio_unittest/rnnt/numpy_transducer.py                         +0   -168
test/torchaudio_unittest/rnnt/rnnt_loss_cpu_test.py                       +0   -9
test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py                      +0   -10
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py                           +0   -87
test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py         +0   -10
test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py        +0   -11
test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py             +0   -70
test/torchaudio_unittest/transforms/autograd_cpu_test.py                  +5   -1
test/torchaudio_unittest/transforms/autograd_cuda_test.py                 +6   -1
test/torchaudio_unittest/transforms/autograd_test_impl.py                 +39  -0
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py   +2   -2
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py  +2   -2
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py       +19  -3
torchaudio/functional/__init__.py                                         +2   -0
torchaudio/functional/functional.py                                       +53  -0
torchaudio/prototype/rnnt_loss.py                                         +0   -117
torchaudio/transforms.py                                                  +55  -0
test/torchaudio_unittest/rnnt/autograd_cpu_test.py deleted 100644 → 0

import torch
from .autograd_impl import Autograd
from torchaudio_unittest import common_utils
from .utils import skipIfNoRNNT


@skipIfNoRNNT
class TestAutograd(Autograd, common_utils.PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cpu')
test/torchaudio_unittest/rnnt/autograd_cuda_test.py deleted 100644 → 0

import torch
from .autograd_impl import Autograd
from torchaudio_unittest import common_utils
from .utils import skipIfNoRNNT


@skipIfNoRNNT
@common_utils.skipIfNoCuda
class TestAutograd(Autograd, common_utils.PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cuda')
test/torchaudio_unittest/rnnt/autograd_impl.py deleted 100644 → 0

from typing import Callable, Tuple
import torch
from torch import Tensor
from torch.autograd import gradcheck
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
)
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
from parameterized import parameterized
from .utils import (
    get_B1_T10_U3_D4_data,
    get_B2_T4_U3_D3_data,
    get_B1_T2_U3_D5_data
)
from .numpy_transducer import NumpyTransducerLoss


class Autograd(TestBaseMixin):
    @staticmethod
    def get_data(data_func, device):
        data = data_func()
        if type(data) == tuple:
            data = data[0]
        return data

    def assert_grad(
            self,
            loss: Callable[..., Tensor],
            inputs: Tuple[torch.Tensor],
            *,
            enable_all_grad: bool = True,
    ):
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon
        assert gradcheck(loss, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_B2_T4_U3_D3_data, ),
        (get_B1_T2_U3_D5_data, ),
    ])
    def test_RNNTLoss_gradcheck(self, data_func):
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),
            data["targets"],
            data["logit_lengths"],
            data["target_lengths"],
        )
        loss = RNNTLoss(blank=data["blank"])

        self.assert_grad(loss, inputs, enable_all_grad=False)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_B2_T4_U3_D3_data, ),
        (get_B1_T2_U3_D5_data, ),
    ])
    def test_rnnt_loss_gradcheck(self, data_func):
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),  # logits
            data["targets"],                # targets
            data["logit_lengths"],          # logit_lengths
            data["target_lengths"],         # target_lengths
            data["blank"],                  # blank
            -1,                             # clamp
        )

        self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_B2_T4_U3_D3_data, ),
        (get_B1_T2_U3_D5_data, ),
    ])
    def test_np_transducer_gradcheck(self, data_func):
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),
            data["logit_lengths"],
            data["target_lengths"],
            data["targets"],
        )
        loss = NumpyTransducerLoss(blank=data["blank"])

        self.assert_grad(loss, inputs, enable_all_grad=False)
test/torchaudio_unittest/rnnt/numpy_transducer.py deleted 100644 → 0

import numpy as np
import torch


class _NumpyTransducer(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        log_probs,
        logit_lengths,
        target_lengths,
        targets,
        blank=-1,
    ):
        device = log_probs.device
        log_probs = log_probs.cpu().data.numpy()
        logit_lengths = logit_lengths.cpu().data.numpy()
        target_lengths = target_lengths.cpu().data.numpy()
        targets = targets.cpu().data.numpy()

        gradients, costs, _, _ = __class__.compute(
            log_probs=log_probs,
            logit_lengths=logit_lengths,
            target_lengths=target_lengths,
            targets=targets,
            blank=blank,
        )

        costs = torch.FloatTensor(costs).to(device=device)
        gradients = torch.FloatTensor(gradients).to(device=device)
        ctx.grads = torch.autograd.Variable(gradients)

        return costs

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
        return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None

    @staticmethod
    def compute_alpha_one_sequence(log_probs, targets, blank=-1):
        max_T, max_U, D = log_probs.shape
        alpha = np.zeros((max_T, max_U), dtype=np.float32)
        for t in range(1, max_T):
            alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank]

        for u in range(1, max_U):
            alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]]

        for t in range(1, max_T):
            for u in range(1, max_U):
                skip = alpha[t - 1, u] + log_probs[t - 1, u, blank]
                emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]]
                alpha[t, u] = np.logaddexp(skip, emit)

        cost = -(alpha[-1, -1] + log_probs[-1, -1, blank])
        return alpha, cost

    @staticmethod
    def compute_beta_one_sequence(log_probs, targets, blank=-1):
        max_T, max_U, D = log_probs.shape
        beta = np.zeros((max_T, max_U), dtype=np.float32)
        beta[-1, -1] = log_probs[-1, -1, blank]

        for t in reversed(range(max_T - 1)):
            beta[t, -1] = beta[t + 1, -1] + log_probs[t, -1, blank]

        for u in reversed(range(max_U - 1)):
            beta[-1, u] = beta[-1, u + 1] + log_probs[-1, u, targets[u]]

        for t in reversed(range(max_T - 1)):
            for u in reversed(range(max_U - 1)):
                skip = beta[t + 1, u] + log_probs[t, u, blank]
                emit = beta[t, u + 1] + log_probs[t, u, targets[u]]
                beta[t, u] = np.logaddexp(skip, emit)

        cost = -beta[0, 0]
        return beta, cost

    @staticmethod
    def compute_gradients_one_sequence(log_probs, alpha, beta, targets, blank=-1):
        max_T, max_U, D = log_probs.shape
        gradients = np.full(log_probs.shape, float("-inf"))
        cost = -beta[0, 0]

        gradients[-1, -1, blank] = alpha[-1, -1]
        gradients[:-1, :, blank] = alpha[:-1, :] + beta[1:, :]

        for u, l in enumerate(targets):
            gradients[:, u, l] = alpha[:, u] + beta[:, u + 1]

        gradients = -(np.exp(gradients + log_probs + cost))
        return gradients

    @staticmethod
    def compute(
        log_probs,
        logit_lengths,
        target_lengths,
        targets,
        blank=-1,
    ):
        gradients = np.zeros_like(log_probs)
        B_tgt, max_T, max_U, D = log_probs.shape
        B_src = logit_lengths.shape[0]

        H = int(B_tgt / B_src)

        alphas = np.zeros((B_tgt, max_T, max_U))
        betas = np.zeros((B_tgt, max_T, max_U))
        betas.fill(float("-inf"))
        alphas.fill(float("-inf"))
        costs = np.zeros(B_tgt)
        for b_tgt in range(B_tgt):
            b_src = int(b_tgt / H)
            T = int(logit_lengths[b_src])
            # NOTE: see https://arxiv.org/pdf/1211.3711.pdf Section 2.1
            U = int(target_lengths[b_tgt]) + 1

            seq_log_probs = log_probs[b_tgt, :T, :U, :]
            seq_targets = targets[b_tgt, :int(target_lengths[b_tgt])]
            alpha, alpha_cost = __class__.compute_alpha_one_sequence(
                log_probs=seq_log_probs, targets=seq_targets, blank=blank
            )

            beta, beta_cost = __class__.compute_beta_one_sequence(
                log_probs=seq_log_probs, targets=seq_targets, blank=blank
            )

            seq_gradients = __class__.compute_gradients_one_sequence(
                log_probs=seq_log_probs,
                alpha=alpha,
                beta=beta,
                targets=seq_targets,
                blank=blank,
            )
            np.testing.assert_almost_equal(alpha_cost, beta_cost, decimal=2)
            gradients[b_tgt, :T, :U, :] = seq_gradients
            costs[b_tgt] = beta_cost
            alphas[b_tgt, :T, :U] = alpha
            betas[b_tgt, :T, :U] = beta

        return gradients, costs, alphas, betas


class NumpyTransducerLoss(torch.nn.Module):
    def __init__(self, blank=-1):
        super().__init__()
        self.blank = blank

    def forward(
        self,
        logits,
        logit_lengths,
        target_lengths,
        targets,
    ):
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        return _NumpyTransducer.apply(
            log_probs,
            logit_lengths,
            target_lengths,
            targets,
            self.blank,
        )
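For reference, this NumPy oracle implements the standard transducer forward-backward recursions from Graves (2012), Section 2.1 (the paper linked in the code). Writing log_probs[t, u, k] as \(\log p_{t,u}(k)\), with \(\varnothing\) the blank label and \(y_1, \dots\) the target sequence, compute_alpha_one_sequence and compute_beta_one_sequence evaluate

\[
\alpha(t,u) = \operatorname{logaddexp}\bigl(\alpha(t-1,u) + \log p_{t-1,u}(\varnothing),\; \alpha(t,u-1) + \log p_{t,u-1}(y_u)\bigr)
\]
\[
\beta(t,u) = \operatorname{logaddexp}\bigl(\beta(t+1,u) + \log p_{t,u}(\varnothing),\; \beta(t,u+1) + \log p_{t,u}(y_{u+1})\bigr)
\]
\[
\text{loss} = -\beta(0,0) = -\bigl(\alpha(T-1,U-1) + \log p_{T-1,U-1}(\varnothing)\bigr)
\]

The np.testing.assert_almost_equal(alpha_cost, beta_cost) call in compute() checks the last identity numerically for every sequence in the batch.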
test/torchaudio_unittest/rnnt/rnnt_loss_cpu_test.py deleted 100644 → 0

import torch
from torchaudio_unittest import common_utils
from .utils import skipIfNoRNNT
from .rnnt_loss_impl import RNNTLossTest


@skipIfNoRNNT
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
    device = torch.device('cpu')
test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py deleted 100644 → 0

import torch
from .rnnt_loss_impl import RNNTLossTest
from torchaudio_unittest import common_utils
from .utils import skipIfNoRNNT


@skipIfNoRNNT
@common_utils.skipIfNoCuda
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
    device = torch.device('cuda')
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py deleted 100644 → 0

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

from .utils import (
    compute_with_numpy_transducer,
    compute_with_pytorch_transducer,
    get_basic_data,
    get_B1_T2_U3_D5_data,
    get_B2_T4_U3_D3_data,
    get_random_data,
)


class RNNTLossTest:
    def _test_costs_and_gradients(self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2):
        logits_shape = data["logits"].shape
        costs, gradients = compute_with_pytorch_transducer(data=data)
        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
        self.assertEqual(logits_shape, gradients.shape)
        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)

    def test_basic_backward(self):
        rnnt_loss = RNNTLoss()
        logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
        loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
        loss.backward()

    def test_basic_forward_no_grad(self):
        rnnt_loss = RNNTLoss()
        logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
        logits.requires_grad_(False)
        rnnt_loss(logits, targets, logit_lengths, target_lengths)

    def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
        data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
            dtype=torch.float32,
            device=self.device,
        )
        self._test_costs_and_gradients(
            data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
        )

    def test_costs_and_gradients_B1_T2_U3_D5_fp16(self):
        data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
            dtype=torch.float16,
            device=self.device,
        )
        self._test_costs_and_gradients(
            data=data,
            ref_costs=ref_costs,
            ref_gradients=ref_gradients,
            atol=1e-3,
            rtol=1e-2,
        )

    def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
        data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
            dtype=torch.float32,
            device=self.device,
        )
        self._test_costs_and_gradients(
            data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
        )

    def test_costs_and_gradients_B2_T4_U3_D3_fp16(self):
        data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
            dtype=torch.float16,
            device=self.device,
        )
        self._test_costs_and_gradients(
            data=data,
            ref_costs=ref_costs,
            ref_gradients=ref_gradients,
            atol=1e-3,
            rtol=1e-2,
        )

    def test_costs_and_gradients_random_data_with_numpy_fp32(self):
        seed = 777
        for i in range(5):
            data = get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
            ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
            self._test_costs_and_gradients(
                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
            )
test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py deleted 100644 → 0

import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .utils import skipIfNoRNNT
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoRNNT
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cpu')
test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py deleted 100644 → 0

import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .utils import skipIfNoRNNT
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoRNNT
@skipIfNoCuda
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cuda')
test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py deleted 100644 → 0

import torch
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss


class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
    """Implements test for RNNT Loss that are performed for different devices"""
    def _assert_consistency(self, func, tensor, shape_only=False):
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor)

        self.assertEqual(ts_output, output)

    def test_rnnt_loss(self):
        def func(
            logits,
        ):
            targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
            logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            return rnnt_loss(logits, targets, logit_lengths, target_lengths)

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])

        self._assert_consistency(func, logits)

    def test_RNNTLoss(self):
        func = RNNTLoss()

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
        logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
        target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)

        tensor = logits.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor, targets, logit_lengths, target_lengths)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)

        self.assertEqual(ts_output, output)
test/torchaudio_unittest/transforms/autograd_cpu_test.py

 from torchaudio_unittest.common_utils import PytorchTestCase
-from .autograd_test_impl import AutogradTestMixin
+from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32


 class AutogradCPUTest(AutogradTestMixin, PytorchTestCase):
     device = 'cpu'
+
+
+class AutogradRNNTCPUTest(AutogradTestFloat32, PytorchTestCase):
+    device = 'cpu'
test/torchaudio_unittest/transforms/autograd_cuda_test.py

@@ -2,9 +2,14 @@ from torchaudio_unittest.common_utils import (
     PytorchTestCase,
     skipIfNoCuda,
 )
-from .autograd_test_impl import AutogradTestMixin
+from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32


 @skipIfNoCuda
 class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
     device = 'cuda'
+
+
+@skipIfNoCuda
+class AutogradRNNTCUDATest(AutogradTestFloat32, PytorchTestCase):
+    device = 'cuda'
test/torchaudio_unittest/transforms/autograd_test_impl.py

@@ -11,6 +11,7 @@ from torchaudio_unittest.common_utils import (
     get_whitenoise,
     get_spectrogram,
     nested_params,
+    rnnt_utils,
 )

@@ -260,3 +261,41 @@ class AutogradTestMixin(TestBaseMixin):
         if test_pseudo_complex:
             spectrogram = torch.view_as_real(spectrogram)
         self.assert_grad(transform, [spectrogram])
+
+
+class AutogradTestFloat32(TestBaseMixin):
+    def assert_grad(
+            self,
+            transform: torch.nn.Module,
+            inputs: List[torch.Tensor],
+    ):
+        inputs_ = []
+        for i in inputs:
+            if torch.is_tensor(i):
+                i = i.to(dtype=torch.float32, device=self.device)
+            inputs_.append(i)
+        # gradcheck with float32 requires higher atol and epsilon
+        assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.)
+
+    @parameterized.expand([
+        (rnnt_utils.get_B1_T10_U3_D4_data, ),
+        (rnnt_utils.get_B2_T4_U3_D3_data, ),
+        (rnnt_utils.get_B1_T2_U3_D5_data, ),
+    ])
+    def test_rnnt_loss(self, data_func):
+        def get_data(data_func, device):
+            data = data_func()
+            if type(data) == tuple:
+                data = data[0]
+            return data
+
+        data = get_data(data_func, self.device)
+        inputs = (
+            data["logits"].to(torch.float32),
+            data["targets"],
+            data["logit_lengths"],
+            data["target_lengths"],
+        )
+        loss = T.RNNTLoss(blank=data["blank"])
+
+        self.assert_grad(loss, inputs)
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py

 import torch
 from torchaudio_unittest.common_utils import PytorchTestCase
-from .torchscript_consistency_impl import Transforms
+from .torchscript_consistency_impl import Transforms, TransformsFloat32Only


-class TestTransformsFloat32(Transforms, PytorchTestCase):
+class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
     dtype = torch.float32
     device = torch.device('cpu')
 ...
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py

 import torch
 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from .torchscript_consistency_impl import Transforms
+from .torchscript_consistency_impl import Transforms, TransformsFloat32Only


 @skipIfNoCuda
-class TestTransformsFloat32(Transforms, PytorchTestCase):
+class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
     dtype = torch.float32
     device = torch.device('cuda')
 ...
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py

@@ -14,7 +14,7 @@ from torchaudio_unittest.common_utils import (
 class Transforms(TempDirMixin, TestBaseMixin):
     """Implements test for Transforms that are performed for different devices"""
-    def _assert_consistency(self, transform, tensor):
+    def _assert_consistency(self, transform, tensor, *args):
         tensor = tensor.to(device=self.device, dtype=self.dtype)
         transform = transform.to(device=self.device, dtype=self.dtype)

@@ -22,8 +22,8 @@ class Transforms(TempDirMixin, TestBaseMixin):
         torch.jit.script(transform).save(path)
         ts_transform = torch.jit.load(path)
-        output = transform(tensor)
-        ts_output = ts_transform(tensor)
+        output = transform(tensor, *args)
+        ts_output = ts_transform(tensor, *args)
         self.assertEqual(ts_output, output)

     def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False):

@@ -155,3 +155,19 @@ class Transforms(TempDirMixin, TestBaseMixin):
             T.PitchShift(sample_rate=sample_rate, n_steps=n_steps),
             waveform
         )
+
+
+class TransformsFloat32Only(TestBaseMixin):
+    def test_rnnt_loss(self):
+        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
+                                 [0.1, 0.1, 0.6, 0.1, 0.1],
+                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
+                                [[0.1, 0.6, 0.1, 0.1, 0.1],
+                                 [0.1, 0.1, 0.2, 0.1, 0.1],
+                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
+        tensor = logits.to(device=self.device, dtype=torch.float32)
+        targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
+        logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
+        target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
+
+        self._assert_consistency(T.RNNTLoss(), logits, targets, logit_lengths, target_lengths)
torchaudio/functional/__init__.py

@@ -25,6 +25,7 @@ from .functional import (
     resample,
     edit_distance,
     pitch_shift,
+    rnnt_loss,
 )
 from .filtering import (
     allpass_biquad,

@@ -98,4 +99,5 @@ __all__ = [
     'resample',
     'edit_distance',
     'pitch_shift',
+    'rnnt_loss',
 ]
torchaudio/functional/functional.py

@@ -40,6 +40,7 @@ __all__ = [
     "resample",
     "edit_distance",
     "pitch_shift",
+    "rnnt_loss",
 ]

@@ -1745,3 +1746,55 @@ def pitch_shift(
     # unpack batch
     waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])

     return waveform_shift
+
+
+def rnnt_loss(
+    logits: Tensor,
+    targets: Tensor,
+    logit_lengths: Tensor,
+    target_lengths: Tensor,
+    blank: int = -1,
+    clamp: float = -1,
+    reduction: str = "mean",
+):
+    """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
+    [:footcite:`graves2012sequence`].
+
+    The RNN Transducer loss extends the CTC loss by defining a distribution over output
+    sequences of all lengths, and by jointly modelling both input-output and output-output
+    dependencies.
+
+    Args:
+        logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
+            containing output from joiner
+        targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
+        logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
+        target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
+        blank (int, optional): blank label (Default: ``-1``)
+        clamp (float, optional): clamp for gradients (Default: ``-1``)
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
+
+    Returns:
+        Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
+        otherwise scalar.
+    """
+    if reduction not in ['none', 'mean', 'sum']:
+        raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
+
+    if blank < 0:
+        # reinterpret blank index if blank < 0.
+        blank = logits.shape[-1] + blank
+
+    costs, _ = torch.ops.torchaudio.rnnt_loss(
+        logits=logits,
+        targets=targets,
+        logit_lengths=logit_lengths,
+        target_lengths=target_lengths,
+        blank=blank,
+        clamp=clamp,
+    )
+
+    if reduction == 'mean':
+        return costs.mean()
+    elif reduction == 'sum':
+        return costs.sum()
+
+    return costs
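As a usage sketch of the newly public functional form, with hypothetical values chosen to match the shapes documented above (this assumes a torchaudio build that includes the RNNT extension, otherwise torch.ops.torchaudio.rnnt_loss is unavailable):

import torch
import torchaudio.functional as F

# (batch, max seq length, max target length + 1, class) output from the joiner
logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)      # (batch, max target length)
logit_lengths = torch.tensor([2], dtype=torch.int32)     # (batch,)
target_lengths = torch.tensor([2], dtype=torch.int32)    # (batch,)

# default reduction='mean' returns a scalar; blank=-1 is reinterpreted as the last class index
loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward()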
torchaudio/prototype/rnnt_loss.py deleted 100644 → 0

import torch
from torch import Tensor

__all__ = [
    "RNNTLoss",
    "rnnt_loss",
]


def rnnt_loss(
    logits: Tensor,
    targets: Tensor,
    logit_lengths: Tensor,
    target_lengths: Tensor,
    blank: int = -1,
    clamp: float = -1,
    reduction: str = "mean",
):
    """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
    [:footcite:`graves2012sequence`].

    The RNN Transducer loss extends the CTC loss by defining a distribution over output
    sequences of all lengths, and by jointly modelling both input-output and output-output
    dependencies.

    Args:
        logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
            containing output from joiner
        targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
        logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
        target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
        blank (int, optional): blank label (Default: ``-1``)
        clamp (float, optional): clamp for gradients (Default: ``-1``)
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)

    Returns:
        Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
        otherwise scalar.
    """
    if reduction not in ['none', 'mean', 'sum']:
        raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")

    if blank < 0:
        # reinterpret blank index if blank < 0.
        blank = logits.shape[-1] + blank

    costs, _ = torch.ops.torchaudio.rnnt_loss(
        logits=logits,
        targets=targets,
        logit_lengths=logit_lengths,
        target_lengths=target_lengths,
        blank=blank,
        clamp=clamp,
    )

    if reduction == 'mean':
        return costs.mean()
    elif reduction == 'sum':
        return costs.sum()

    return costs


class RNNTLoss(torch.nn.Module):
    """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
    [:footcite:`graves2012sequence`].

    The RNN Transducer loss extends the CTC loss by defining a distribution over output
    sequences of all lengths, and by jointly modelling both input-output and output-output
    dependencies.

    Args:
        blank (int, optional): blank label (Default: ``-1``)
        clamp (float, optional): clamp for gradients (Default: ``-1``)
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
    """

    def __init__(
        self,
        blank: int = -1,
        clamp: float = -1.,
        reduction: str = "mean",
    ):
        super().__init__()
        self.blank = blank
        self.clamp = clamp
        self.reduction = reduction

    def forward(
        self,
        logits,
        targets,
        logit_lengths,
        target_lengths,
    ):
        """
        Args:
            logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
                containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
        Returns:
            Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
            otherwise scalar.
        """
        return rnnt_loss(
            logits,
            targets,
            logit_lengths,
            target_lengths,
            self.blank,
            self.clamp,
            self.reduction,
        )
torchaudio/transforms.py

@@ -37,6 +37,7 @@ __all__ = [
     'Vol',
     'ComputeDeltas',
     'PitchShift',
+    'RNNTLoss',
 ]

@@ -1428,3 +1429,57 @@ class PitchShift(torch.nn.Module):
         return F.pitch_shift(waveform, self.sample_rate, self.n_steps, self.bins_per_octave, self.n_fft,
                              self.win_length, self.hop_length, self.window)
+
+
+class RNNTLoss(torch.nn.Module):
+    """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
+    [:footcite:`graves2012sequence`].
+
+    The RNN Transducer loss extends the CTC loss by defining a distribution over output
+    sequences of all lengths, and by jointly modelling both input-output and output-output
+    dependencies.
+
+    Args:
+        blank (int, optional): blank label (Default: ``-1``)
+        clamp (float, optional): clamp for gradients (Default: ``-1``)
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
+    """
+
+    def __init__(
+        self,
+        blank: int = -1,
+        clamp: float = -1.,
+        reduction: str = "mean",
+    ):
+        super().__init__()
+        self.blank = blank
+        self.clamp = clamp
+        self.reduction = reduction
+
+    def forward(
+        self,
+        logits: Tensor,
+        targets: Tensor,
+        logit_lengths: Tensor,
+        target_lengths: Tensor,
+    ):
+        """
+        Args:
+            logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
+                containing output from joiner
+            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
+            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
+            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
+        Returns:
+            Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
+            otherwise scalar.
+        """
+        return F.rnnt_loss(
+            logits,
+            targets,
+            logit_lengths,
+            target_lengths,
+            self.blank,
+            self.clamp,
+            self.reduction,
+        )
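The transforms.RNNTLoss module is a thin wrapper over F.rnnt_loss: blank, clamp, and reduction are bound at construction time, which lets the loss be configured once and scripted like any other transform (the TorchScript consistency tests above exercise exactly that). A short sketch with the same hypothetical tensors as in the functional example:

import torch
import torchaudio.transforms as T

logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

loss_fn = T.RNNTLoss(reduction="sum")   # hyperparameters fixed at construction
loss = loss_fn(logits, targets, logit_lengths, target_lengths)
loss.backward()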