Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
0ea6d10d
Unverified
Commit
0ea6d10d
authored
Jul 15, 2021
by
Caroline Chen
Committed by
GitHub
Jul 15, 2021
Browse files
Refactor RNNT Loss Unit Tests (#1630)
parent
56ab0368
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
120 additions
and
152 deletions
+120
-152
test/torchaudio_unittest/rnnt/autograd_impl.py
test/torchaudio_unittest/rnnt/autograd_impl.py
+11
-15
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
+23
-38
test/torchaudio_unittest/rnnt/utils.py
test/torchaudio_unittest/rnnt/utils.py
+86
-99
No files found.
test/torchaudio_unittest/rnnt/autograd_impl.py
View file @
0ea6d10d
...
...
@@ -8,10 +8,9 @@ from torchaudio_unittest.common_utils import (
from
torchaudio.prototype.rnnt_loss
import
RNNTLoss
,
rnnt_loss
from
parameterized
import
parameterized
from
.utils
import
(
numpy_to_torch
,
get_B1_T10_U3_D4_data
,
get_
numpy_data_
B2_T4_U3_D3
,
get_
numpy_data_
B1_T2_U3_D5
get_B2_T4_U3_D3
_data
,
get_B1_T2_U3_D5
_data
)
from
.numpy_transducer
import
NumpyTransducerLoss
...
...
@@ -19,12 +18,9 @@ from .numpy_transducer import NumpyTransducerLoss
class
Autograd
(
TestBaseMixin
):
@
staticmethod
def
get_data
(
data_func
,
device
):
data_np
=
data_func
()
if
type
(
data_np
)
==
tuple
:
data_np
=
data_np
[
0
]
data
=
numpy_to_torch
(
data
=
data_np
,
device
=
device
,
requires_grad
=
True
)
data
=
data_func
()
if
type
(
data
)
==
tuple
:
data
=
data
[
0
]
return
data
def
assert_grad
(
...
...
@@ -46,8 +42,8 @@ class Autograd(TestBaseMixin):
@
parameterized
.
expand
([
(
get_B1_T10_U3_D4_data
,
),
(
get_
numpy_data_
B2_T4_U3_D3
,
),
(
get_
numpy_data_
B1_T2_U3_D5
,
),
(
get_B2_T4_U3_D3
_data
,
),
(
get_B1_T2_U3_D5
_data
,
),
])
def
test_RNNTLoss_gradcheck
(
self
,
data_func
):
data
=
self
.
get_data
(
data_func
,
self
.
device
)
...
...
@@ -63,8 +59,8 @@ class Autograd(TestBaseMixin):
@
parameterized
.
expand
([
(
get_B1_T10_U3_D4_data
,
),
(
get_
numpy_data_
B2_T4_U3_D3
,
),
(
get_
numpy_data_
B1_T2_U3_D5
,
),
(
get_B2_T4_U3_D3
_data
,
),
(
get_B1_T2_U3_D5
_data
,
),
])
def
test_rnnt_loss_gradcheck
(
self
,
data_func
):
data
=
self
.
get_data
(
data_func
,
self
.
device
)
...
...
@@ -83,8 +79,8 @@ class Autograd(TestBaseMixin):
@
parameterized
.
expand
([
(
get_B1_T10_U3_D4_data
,
),
(
get_
numpy_data_
B2_T4_U3_D3
,
),
(
get_
numpy_data_
B1_T2_U3_D5
,
),
(
get_B2_T4_U3_D3
_data
,
),
(
get_B1_T2_U3_D5
_data
,
),
])
def
test_np_transducer_gradcheck
(
self
,
data_func
):
data
=
self
.
get_data
(
data_func
,
self
.
device
)
...
...
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
View file @
0ea6d10d
import
numpy
as
np
import
torch
from
torchaudio.prototype.rnnt_loss
import
RNNTLoss
from
.utils
import
(
compute_with_numpy_transducer
,
compute_with_pytorch_transducer
,
get_basic_data
,
get_B1_T10_U3_D4_data
,
get_data_basic
,
get_numpy_data_B1_T2_U3_D5
,
get_numpy_data_B2_T4_U3_D3
,
get_numpy_random_data
,
numpy_to_torch
,
get_B1_T2_U3_D5_data
,
get_B2_T4_U3_D3_data
,
get_random_data
,
)
...
...
@@ -23,42 +22,30 @@ class RNNTLossTest:
costs
,
gradients
=
compute_with_pytorch_transducer
(
data
=
data
,
reuse_logits_for_grads
=
reuse_logits_for_grads
)
np
.
testing
.
assert_allclose
(
costs
,
ref_costs
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
costs
,
ref_costs
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
logits_shape
,
gradients
.
shape
)
if
not
np
.
allclose
(
gradients
,
ref_gradients
,
atol
=
atol
,
rtol
=
rtol
):
for
b
in
range
(
len
(
gradients
)):
T
=
data
[
"logit_lengths"
][
b
]
U
=
data
[
"target_lengths"
][
b
]
for
t
in
range
(
gradients
.
shape
[
1
]):
for
u
in
range
(
gradients
.
shape
[
2
]):
np
.
testing
.
assert_allclose
(
gradients
[
b
,
t
,
u
],
ref_gradients
[
b
,
t
,
u
],
atol
=
atol
,
rtol
=
rtol
,
err_msg
=
f
"failed on b=
{
b
}
, t=
{
t
}
/T=
{
T
}
, u=
{
u
}
/U=
{
U
}
"
,
)
self
.
assertEqual
(
gradients
,
ref_gradients
,
atol
=
atol
,
rtol
=
rtol
)
def
test_basic_backward
(
self
):
rnnt_loss
=
RNNTLoss
()
logits
,
targets
,
logit_lengths
,
target_lengths
=
get_
data_
basic
(
self
.
device
)
logits
,
targets
,
logit_lengths
,
target_lengths
=
get_basic
_data
(
self
.
device
)
loss
=
rnnt_loss
(
logits
,
targets
,
logit_lengths
,
target_lengths
)
loss
.
backward
()
def
test_costs_and_gradients_B1_T2_U3_D5_fp32
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_numpy_data_B1_T2_U3_D5
(
dtype
=
np
.
float32
data
,
ref_costs
,
ref_gradients
=
get_B1_T2_U3_D5_data
(
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
)
def
test_costs_and_gradients_B1_T2_U3_D5_fp16
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_numpy_data_B1_T2_U3_D5
(
dtype
=
np
.
float16
data
,
ref_costs
,
ref_gradients
=
get_B1_T2_U3_D5_data
(
dtype
=
torch
.
float16
,
device
=
self
.
device
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
...
...
@@ -68,19 +55,19 @@ class RNNTLossTest:
)
def
test_costs_and_gradients_B2_T4_U3_D3_fp32
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_numpy_data_B2_T4_U3_D3
(
dtype
=
np
.
float32
data
,
ref_costs
,
ref_gradients
=
get_B2_T4_U3_D3_data
(
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
)
def
test_costs_and_gradients_B2_T4_U3_D3_fp16
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_numpy_data_B2_T4_U3_D3
(
dtype
=
np
.
float16
data
,
ref_costs
,
ref_gradients
=
get_B2_T4_U3_D3_data
(
dtype
=
torch
.
float16
,
device
=
self
.
device
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
...
...
@@ -92,8 +79,7 @@ class RNNTLossTest:
def
test_costs_and_gradients_random_data_with_numpy_fp32
(
self
):
seed
=
777
for
i
in
range
(
5
):
data
=
get_numpy_random_data
(
dtype
=
np
.
float32
,
seed
=
(
seed
+
i
))
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
data
=
get_random_data
(
dtype
=
torch
.
float32
,
device
=
self
.
device
,
seed
=
(
seed
+
i
))
ref_costs
,
ref_gradients
=
compute_with_numpy_transducer
(
data
=
data
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
...
...
@@ -103,9 +89,8 @@ class RNNTLossTest:
for
random
in
[
False
,
True
]:
data
=
get_B1_T10_U3_D4_data
(
random
=
random
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
data
[
"fused_log_softmax"
]
=
False
ref_costs
,
ref_gradients
=
compute_with_numpy_transducer
(
...
...
test/torchaudio_unittest/rnnt/utils.py
View file @
0ea6d10d
import
unittest
import
numpy
as
np
import
random
import
torch
from
torchaudio.prototype.rnnt_loss
import
RNNTLoss
...
...
@@ -19,10 +18,8 @@ def compute_with_numpy_transducer(data):
loss
=
torch
.
sum
(
costs
)
loss
.
backward
()
costs
=
costs
.
cpu
().
data
.
numpy
()
gradients
=
data
[
"logits"
].
saved_grad
.
cpu
().
data
.
numpy
()
costs
=
costs
.
cpu
()
gradients
=
data
[
"logits"
].
saved_grad
.
cpu
()
return
costs
,
gradients
...
...
@@ -41,12 +38,12 @@ def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
loss
=
torch
.
sum
(
costs
)
loss
.
backward
()
costs
=
costs
.
cpu
()
.
data
.
numpy
()
gradients
=
data
[
"logits"
].
saved_grad
.
cpu
()
.
data
.
numpy
()
costs
=
costs
.
cpu
()
gradients
=
data
[
"logits"
].
saved_grad
.
cpu
()
return
costs
,
gradients
def
get_
data_
basic
(
device
):
def
get_basic
_data
(
device
):
# Example provided
# in 6f73a2513dc784c59eec153a45f40bc528355b18
# of https://github.com/HawkAaron/warp-transducer
...
...
@@ -66,16 +63,12 @@ def get_data_basic(device):
],
]
],
dtype
=
torch
.
float
,
dtype
=
torch
.
float32
,
device
=
device
,
)
targets
=
torch
.
tensor
([[
1
,
2
]],
dtype
=
torch
.
int
)
logit_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int
)
target_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int
)
logits
=
logits
.
to
(
device
=
device
)
targets
=
targets
.
to
(
device
=
device
)
logit_lengths
=
logit_lengths
.
to
(
device
=
device
)
target_lengths
=
target_lengths
.
to
(
device
=
device
)
targets
=
torch
.
tensor
([[
1
,
2
]],
dtype
=
torch
.
int
,
device
=
device
)
logit_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int
,
device
=
device
)
target_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int
,
device
=
device
)
logits
.
requires_grad_
(
True
)
...
...
@@ -84,27 +77,32 @@ def get_data_basic(device):
def
get_B1_T10_U3_D4_data
(
random
=
False
,
dtype
=
np
.
float32
,
nan
=
False
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cpu"
)
,
):
B
,
T
,
U
,
D
=
2
,
10
,
3
,
4
data
=
{}
data
[
"
logits
"
]
=
np
.
random
.
rand
(
B
,
T
,
U
,
D
).
as
type
(
dtype
)
logits
=
torch
.
rand
(
B
,
T
,
U
,
D
,
d
type
=
dtype
,
device
=
device
)
if
not
random
:
data
[
"logits"
].
fill
(
0.1
)
if
nan
:
for
i
in
range
(
B
):
data
[
"logits"
][
i
][
0
][
0
][
0
]
=
np
.
nan
data
[
"logit_lengths"
]
=
np
.
array
([
10
,
10
],
dtype
=
np
.
int32
)
data
[
"target_lengths"
]
=
np
.
array
([
2
,
2
],
dtype
=
np
.
int32
)
data
[
"targets"
]
=
np
.
array
([[
1
,
2
],
[
1
,
2
]],
dtype
=
np
.
int32
)
logits
.
fill_
(
0.1
)
logits
.
requires_grad_
(
True
)
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
data
=
{}
data
[
"logits"
]
=
logits
data
[
"logit_lengths"
]
=
torch
.
tensor
([
10
,
10
],
dtype
=
torch
.
int32
,
device
=
device
)
data
[
"target_lengths"
]
=
torch
.
tensor
([
2
,
2
],
dtype
=
torch
.
int32
,
device
=
device
)
data
[
"targets"
]
=
torch
.
tensor
([[
1
,
2
],
[
1
,
2
]],
dtype
=
torch
.
int32
,
device
=
device
)
data
[
"blank"
]
=
0
return
data
def
get_
numpy_data_
B1_T2_U3_D5
(
dtype
=
np
.
float32
):
logits
=
np
.
array
(
def
get_B1_T2_U3_D5
_data
(
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cpu"
)
):
logits
=
torch
.
tensor
(
[
0.1
,
0.6
,
...
...
@@ -138,15 +136,22 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
0.1
,
],
dtype
=
dtype
,
device
=
device
,
).
reshape
(
1
,
2
,
3
,
5
)
targets
=
np
.
array
([[
1
,
2
]],
dtype
=
np
.
int32
)
logit_lengths
=
np
.
array
([
2
],
dtype
=
np
.
int32
)
target_lengths
=
np
.
array
([
2
],
dtype
=
np
.
int32
)
logits
.
requires_grad_
(
True
)
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
targets
=
torch
.
tensor
([[
1
,
2
]],
dtype
=
torch
.
int32
,
device
=
device
)
logit_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int32
,
device
=
device
)
target_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int32
,
device
=
device
)
blank
=
-
1
ref_costs
=
np
.
array
([
5.09566688538
],
dtype
=
dtype
)
ref_gradients
=
np
.
array
(
ref_costs
=
torch
.
tensor
([
5.09566688538
],
dtype
=
dtype
)
ref_gradients
=
torch
.
tensor
(
[
0.17703132
,
-
0.39992708
,
...
...
@@ -193,10 +198,9 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
return
data
,
ref_costs
,
ref_gradients
def
get_
numpy_data_
B2_T4_U3_D3
(
dtype
=
np
.
float32
):
def
get_B2_T4_U3_D3
_data
(
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cpu"
)
):
# Test from D21322854
logits
=
np
.
array
(
logits
=
torch
.
tensor
(
[
0.065357
,
0.787530
,
...
...
@@ -272,17 +276,23 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
0.358021
,
],
dtype
=
dtype
,
device
=
device
,
).
reshape
(
2
,
4
,
3
,
3
)
logits
.
requires_grad_
(
True
)
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
targets
=
np
.
array
([[
1
,
2
],
[
1
,
1
]],
dtype
=
np
.
int32
)
logit_lengths
=
np
.
array
([
4
,
4
],
dtype
=
np
.
int32
)
target_lengths
=
np
.
array
([
2
,
2
],
dtype
=
np
.
int32
)
targets
=
torch
.
tensor
([[
1
,
2
],
[
1
,
1
]],
dtype
=
torch
.
int32
,
device
=
device
)
logit_lengths
=
torch
.
tensor
([
4
,
4
],
dtype
=
torch
.
int32
,
device
=
device
)
target_lengths
=
torch
.
tensor
([
2
,
2
],
dtype
=
torch
.
int32
,
device
=
device
)
blank
=
0
ref_costs
=
np
.
array
([
4.2806528590890736
,
3.9384369822503591
],
dtype
=
dtype
)
ref_costs
=
torch
.
tensor
([
4.2806528590890736
,
3.9384369822503591
],
dtype
=
dtype
)
ref_gradients
=
np
.
array
(
ref_gradients
=
torch
.
tensor
(
[
-
0.186844
,
-
0.062555
,
...
...
@@ -371,30 +381,45 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
return
data
,
ref_costs
,
ref_gradients
def
get_numpy_random_data
(
max_B
=
8
,
max_T
=
128
,
max_U
=
32
,
max_D
=
40
,
blank
=-
1
,
dtype
=
np
.
float32
,
seed
=
None
def
get_random_data
(
max_B
=
8
,
max_T
=
128
,
max_U
=
32
,
max_D
=
40
,
blank
=-
1
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cpu"
),
seed
=
None
,
):
if
seed
is
not
None
:
np
.
random
.
seed
(
seed
=
seed
)
torch
.
manual_
seed
(
seed
=
seed
)
if
blank
!=
-
1
:
raise
ValueError
(
"blank != -1 is not supported yet."
)
B
=
np
.
random
.
randint
(
low
=
1
,
high
=
max_B
)
T
=
np
.
random
.
randint
(
low
=
5
,
high
=
max_T
)
U
=
np
.
random
.
randint
(
low
=
5
,
high
=
max_U
)
D
=
np
.
random
.
randint
(
low
=
2
,
high
=
max_D
)
logit_lengths
=
np
.
random
.
randint
(
low
=
5
,
high
=
T
+
1
,
size
=
(
B
,),
dtype
=
np
.
int32
)
target_lengths
=
np
.
random
.
randint
(
low
=
5
,
high
=
U
+
1
,
size
=
(
B
,),
dtype
=
np
.
int32
)
max_src_length
=
np
.
max
(
logit_lengths
)
max_tgt_length
=
np
.
max
(
target_lengths
)
targets
=
np
.
random
.
randint
(
low
=
0
,
high
=
D
-
1
,
size
=
(
B
,
max_tgt_length
),
dtype
=
np
.
int32
random
.
seed
(
0
)
B
=
random
.
randint
(
1
,
max_B
-
1
)
T
=
random
.
randint
(
5
,
max_T
-
1
)
U
=
random
.
randint
(
5
,
max_U
-
1
)
D
=
random
.
randint
(
2
,
max_D
-
1
)
logit_lengths
=
torch
.
randint
(
low
=
5
,
high
=
T
+
1
,
size
=
(
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
target_lengths
=
torch
.
randint
(
low
=
5
,
high
=
U
+
1
,
size
=
(
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
max_src_length
=
torch
.
max
(
logit_lengths
)
max_tgt_length
=
torch
.
max
(
target_lengths
)
targets
=
torch
.
randint
(
low
=
0
,
high
=
D
-
1
,
size
=
(
B
,
max_tgt_length
),
dtype
=
torch
.
int32
,
device
=
device
)
logits
=
np
.
random
.
random_sample
(
size
=
(
B
,
max_src_length
,
max_tgt_length
+
1
,
D
)
).
astype
(
dtype
=
dtype
)
logits
=
torch
.
rand
(
size
=
(
B
,
max_src_length
,
max_tgt_length
+
1
,
D
),
dtype
=
dtype
,
device
=
device
,
).
requires_grad_
(
True
)
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
return
{
"logits"
:
logits
,
...
...
@@ -405,44 +430,6 @@ def get_numpy_random_data(
}
def
numpy_to_torch
(
data
,
device
,
requires_grad
=
True
):
logits
=
torch
.
from_numpy
(
data
[
"logits"
]).
to
(
device
=
device
)
targets
=
torch
.
from_numpy
(
data
[
"targets"
]).
to
(
device
=
device
)
logit_lengths
=
torch
.
from_numpy
(
data
[
"logit_lengths"
]).
to
(
device
=
device
)
target_lengths
=
torch
.
from_numpy
(
data
[
"target_lengths"
]).
to
(
device
=
device
)
if
"nbest_wers"
in
data
:
data
[
"nbest_wers"
]
=
torch
.
from_numpy
(
data
[
"nbest_wers"
]).
to
(
device
=
device
)
if
"nbest_scores"
in
data
:
data
[
"nbest_scores"
]
=
torch
.
from_numpy
(
data
[
"nbest_scores"
]).
to
(
device
=
device
)
logits
=
torch
.
autograd
.
Variable
(
logits
,
requires_grad
=
requires_grad
)
logit_lengths
=
torch
.
autograd
.
Variable
(
logit_lengths
)
target_lengths
=
torch
.
autograd
.
Variable
(
target_lengths
)
targets
=
torch
.
autograd
.
Variable
(
targets
)
if
device
==
torch
.
device
(
"cpu"
):
logits
=
logits
.
cpu
()
elif
device
==
torch
.
device
(
"cuda"
):
logits
=
logits
.
cuda
()
else
:
raise
ValueError
(
"unrecognized device = {}"
.
format
(
device
))
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
data
[
"logits"
]
=
logits
data
[
"logit_lengths"
]
=
logit_lengths
data
[
"target_lengths"
]
=
target_lengths
data
[
"targets"
]
=
targets
return
data
def
skipIfNoRNNT
(
test_item
):
try
:
torch
.
ops
.
torchaudio
.
rnnt_loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment