Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
0ea6d10d
"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "916b9a5e65fc176275df89aeb460897a4d450232"
Unverified
Commit
0ea6d10d
authored
Jul 15, 2021
by
Caroline Chen
Committed by
GitHub
Jul 15, 2021
Browse files
Refactor RNNT Loss Unit Tests (#1630)
parent
56ab0368
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
120 additions
and
152 deletions
+120
-152
test/torchaudio_unittest/rnnt/autograd_impl.py
test/torchaudio_unittest/rnnt/autograd_impl.py
+11
-15
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
+23
-38
test/torchaudio_unittest/rnnt/utils.py
test/torchaudio_unittest/rnnt/utils.py
+86
-99
No files found.
test/torchaudio_unittest/rnnt/autograd_impl.py
View file @
0ea6d10d
...
...
@@ -8,10 +8,9 @@ from torchaudio_unittest.common_utils import (
from
torchaudio.prototype.rnnt_loss
import
RNNTLoss
,
rnnt_loss
from
parameterized
import
parameterized
from
.utils
import
(
numpy_to_torch
,
get_B1_T10_U3_D4_data
,
get_
numpy_data_
B2_T4_U3_D3
,
get_
numpy_data_
B1_T2_U3_D5
get_B2_T4_U3_D3
_data
,
get_B1_T2_U3_D5
_data
)
from
.numpy_transducer
import
NumpyTransducerLoss
...
...
@@ -19,12 +18,9 @@ from .numpy_transducer import NumpyTransducerLoss
class
Autograd
(
TestBaseMixin
):
@
staticmethod
def
get_data
(
data_func
,
device
):
data_np
=
data_func
()
if
type
(
data_np
)
==
tuple
:
data_np
=
data_np
[
0
]
data
=
numpy_to_torch
(
data
=
data_np
,
device
=
device
,
requires_grad
=
True
)
data
=
data_func
()
if
type
(
data
)
==
tuple
:
data
=
data
[
0
]
return
data
def
assert_grad
(
...
...
@@ -46,8 +42,8 @@ class Autograd(TestBaseMixin):
@
parameterized
.
expand
([
(
get_B1_T10_U3_D4_data
,
),
(
get_
numpy_data_
B2_T4_U3_D3
,
),
(
get_
numpy_data_
B1_T2_U3_D5
,
),
(
get_B2_T4_U3_D3
_data
,
),
(
get_B1_T2_U3_D5
_data
,
),
])
def
test_RNNTLoss_gradcheck
(
self
,
data_func
):
data
=
self
.
get_data
(
data_func
,
self
.
device
)
...
...
@@ -63,8 +59,8 @@ class Autograd(TestBaseMixin):
@
parameterized
.
expand
([
(
get_B1_T10_U3_D4_data
,
),
(
get_
numpy_data_
B2_T4_U3_D3
,
),
(
get_
numpy_data_
B1_T2_U3_D5
,
),
(
get_B2_T4_U3_D3
_data
,
),
(
get_B1_T2_U3_D5
_data
,
),
])
def
test_rnnt_loss_gradcheck
(
self
,
data_func
):
data
=
self
.
get_data
(
data_func
,
self
.
device
)
...
...
@@ -83,8 +79,8 @@ class Autograd(TestBaseMixin):
@
parameterized
.
expand
([
(
get_B1_T10_U3_D4_data
,
),
(
get_
numpy_data_
B2_T4_U3_D3
,
),
(
get_
numpy_data_
B1_T2_U3_D5
,
),
(
get_B2_T4_U3_D3
_data
,
),
(
get_B1_T2_U3_D5
_data
,
),
])
def
test_np_transducer_gradcheck
(
self
,
data_func
):
data
=
self
.
get_data
(
data_func
,
self
.
device
)
...
...
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
View file @
0ea6d10d
import
numpy
as
np
import
torch
from
torchaudio.prototype.rnnt_loss
import
RNNTLoss
from
.utils
import
(
compute_with_numpy_transducer
,
compute_with_pytorch_transducer
,
get_basic_data
,
get_B1_T10_U3_D4_data
,
get_data_basic
,
get_numpy_data_B1_T2_U3_D5
,
get_numpy_data_B2_T4_U3_D3
,
get_numpy_random_data
,
numpy_to_torch
,
get_B1_T2_U3_D5_data
,
get_B2_T4_U3_D3_data
,
get_random_data
,
)
...
...
@@ -23,42 +22,30 @@ class RNNTLossTest:
costs
,
gradients
=
compute_with_pytorch_transducer
(
data
=
data
,
reuse_logits_for_grads
=
reuse_logits_for_grads
)
np
.
testing
.
assert_allclose
(
costs
,
ref_costs
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
costs
,
ref_costs
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
logits_shape
,
gradients
.
shape
)
if
not
np
.
allclose
(
gradients
,
ref_gradients
,
atol
=
atol
,
rtol
=
rtol
):
for
b
in
range
(
len
(
gradients
)):
T
=
data
[
"logit_lengths"
][
b
]
U
=
data
[
"target_lengths"
][
b
]
for
t
in
range
(
gradients
.
shape
[
1
]):
for
u
in
range
(
gradients
.
shape
[
2
]):
np
.
testing
.
assert_allclose
(
gradients
[
b
,
t
,
u
],
ref_gradients
[
b
,
t
,
u
],
atol
=
atol
,
rtol
=
rtol
,
err_msg
=
f
"failed on b=
{
b
}
, t=
{
t
}
/T=
{
T
}
, u=
{
u
}
/U=
{
U
}
"
,
)
self
.
assertEqual
(
gradients
,
ref_gradients
,
atol
=
atol
,
rtol
=
rtol
)
def
test_basic_backward
(
self
):
rnnt_loss
=
RNNTLoss
()
logits
,
targets
,
logit_lengths
,
target_lengths
=
get_
data_
basic
(
self
.
device
)
logits
,
targets
,
logit_lengths
,
target_lengths
=
get_basic
_data
(
self
.
device
)
loss
=
rnnt_loss
(
logits
,
targets
,
logit_lengths
,
target_lengths
)
loss
.
backward
()
def
test_costs_and_gradients_B1_T2_U3_D5_fp32
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_numpy_data_B1_T2_U3_D5
(
dtype
=
np
.
float32
data
,
ref_costs
,
ref_gradients
=
get_B1_T2_U3_D5_data
(
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
)
def
test_costs_and_gradients_B1_T2_U3_D5_fp16
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_numpy_data_B1_T2_U3_D5
(
dtype
=
np
.
float16
data
,
ref_costs
,
ref_gradients
=
get_B1_T2_U3_D5_data
(
dtype
=
torch
.
float16
,
device
=
self
.
device
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
...
...
@@ -68,19 +55,19 @@ class RNNTLossTest:
)
def
test_costs_and_gradients_B2_T4_U3_D3_fp32
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_numpy_data_B2_T4_U3_D3
(
dtype
=
np
.
float32
data
,
ref_costs
,
ref_gradients
=
get_B2_T4_U3_D3_data
(
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
)
def
test_costs_and_gradients_B2_T4_U3_D3_fp16
(
self
):
data
,
ref_costs
,
ref_gradients
=
get_numpy_data_B2_T4_U3_D3
(
dtype
=
np
.
float16
data
,
ref_costs
,
ref_gradients
=
get_B2_T4_U3_D3_data
(
dtype
=
torch
.
float16
,
device
=
self
.
device
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
...
...
@@ -92,8 +79,7 @@ class RNNTLossTest:
def
test_costs_and_gradients_random_data_with_numpy_fp32
(
self
):
seed
=
777
for
i
in
range
(
5
):
data
=
get_numpy_random_data
(
dtype
=
np
.
float32
,
seed
=
(
seed
+
i
))
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
)
data
=
get_random_data
(
dtype
=
torch
.
float32
,
device
=
self
.
device
,
seed
=
(
seed
+
i
))
ref_costs
,
ref_gradients
=
compute_with_numpy_transducer
(
data
=
data
)
self
.
_test_costs_and_gradients
(
data
=
data
,
ref_costs
=
ref_costs
,
ref_gradients
=
ref_gradients
...
...
@@ -103,9 +89,8 @@ class RNNTLossTest:
for
random
in
[
False
,
True
]:
data
=
get_B1_T10_U3_D4_data
(
random
=
random
,
)
data
=
numpy_to_torch
(
data
=
data
,
device
=
self
.
device
,
requires_grad
=
True
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
data
[
"fused_log_softmax"
]
=
False
ref_costs
,
ref_gradients
=
compute_with_numpy_transducer
(
...
...
test/torchaudio_unittest/rnnt/utils.py
View file @
0ea6d10d
import
unittest
import
numpy
as
np
import
random
import
torch
from
torchaudio.prototype.rnnt_loss
import
RNNTLoss
...
...
@@ -19,10 +18,8 @@ def compute_with_numpy_transducer(data):
loss
=
torch
.
sum
(
costs
)
loss
.
backward
()
costs
=
costs
.
cpu
().
data
.
numpy
()
gradients
=
data
[
"logits"
].
saved_grad
.
cpu
().
data
.
numpy
()
costs
=
costs
.
cpu
()
gradients
=
data
[
"logits"
].
saved_grad
.
cpu
()
return
costs
,
gradients
...
...
@@ -41,12 +38,12 @@ def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
loss
=
torch
.
sum
(
costs
)
loss
.
backward
()
costs
=
costs
.
cpu
()
.
data
.
numpy
()
gradients
=
data
[
"logits"
].
saved_grad
.
cpu
()
.
data
.
numpy
()
costs
=
costs
.
cpu
()
gradients
=
data
[
"logits"
].
saved_grad
.
cpu
()
return
costs
,
gradients
def
get_
data_
basic
(
device
):
def
get_basic
_data
(
device
):
# Example provided
# in 6f73a2513dc784c59eec153a45f40bc528355b18
# of https://github.com/HawkAaron/warp-transducer
...
...
@@ -66,16 +63,12 @@ def get_data_basic(device):
],
]
],
dtype
=
torch
.
float
,
dtype
=
torch
.
float32
,
device
=
device
,
)
targets
=
torch
.
tensor
([[
1
,
2
]],
dtype
=
torch
.
int
)
logit_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int
)
target_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int
)
logits
=
logits
.
to
(
device
=
device
)
targets
=
targets
.
to
(
device
=
device
)
logit_lengths
=
logit_lengths
.
to
(
device
=
device
)
target_lengths
=
target_lengths
.
to
(
device
=
device
)
targets
=
torch
.
tensor
([[
1
,
2
]],
dtype
=
torch
.
int
,
device
=
device
)
logit_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int
,
device
=
device
)
target_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int
,
device
=
device
)
logits
.
requires_grad_
(
True
)
...
...
@@ -84,27 +77,32 @@ def get_data_basic(device):
def
get_B1_T10_U3_D4_data
(
random
=
False
,
dtype
=
np
.
float32
,
nan
=
False
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cpu"
)
,
):
B
,
T
,
U
,
D
=
2
,
10
,
3
,
4
data
=
{}
data
[
"
logits
"
]
=
np
.
random
.
rand
(
B
,
T
,
U
,
D
).
as
type
(
dtype
)
logits
=
torch
.
rand
(
B
,
T
,
U
,
D
,
d
type
=
dtype
,
device
=
device
)
if
not
random
:
data
[
"logits"
].
fill
(
0.1
)
if
nan
:
for
i
in
range
(
B
):
data
[
"logits"
][
i
][
0
][
0
][
0
]
=
np
.
nan
data
[
"logit_lengths"
]
=
np
.
array
([
10
,
10
],
dtype
=
np
.
int32
)
data
[
"target_lengths"
]
=
np
.
array
([
2
,
2
],
dtype
=
np
.
int32
)
data
[
"targets"
]
=
np
.
array
([[
1
,
2
],
[
1
,
2
]],
dtype
=
np
.
int32
)
logits
.
fill_
(
0.1
)
logits
.
requires_grad_
(
True
)
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
data
=
{}
data
[
"logits"
]
=
logits
data
[
"logit_lengths"
]
=
torch
.
tensor
([
10
,
10
],
dtype
=
torch
.
int32
,
device
=
device
)
data
[
"target_lengths"
]
=
torch
.
tensor
([
2
,
2
],
dtype
=
torch
.
int32
,
device
=
device
)
data
[
"targets"
]
=
torch
.
tensor
([[
1
,
2
],
[
1
,
2
]],
dtype
=
torch
.
int32
,
device
=
device
)
data
[
"blank"
]
=
0
return
data
def
get_
numpy_data_
B1_T2_U3_D5
(
dtype
=
np
.
float32
):
logits
=
np
.
array
(
def
get_B1_T2_U3_D5
_data
(
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cpu"
)
):
logits
=
torch
.
tensor
(
[
0.1
,
0.6
,
...
...
@@ -138,15 +136,22 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
0.1
,
],
dtype
=
dtype
,
device
=
device
,
).
reshape
(
1
,
2
,
3
,
5
)
targets
=
np
.
array
([[
1
,
2
]],
dtype
=
np
.
int32
)
logit_lengths
=
np
.
array
([
2
],
dtype
=
np
.
int32
)
target_lengths
=
np
.
array
([
2
],
dtype
=
np
.
int32
)
logits
.
requires_grad_
(
True
)
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
targets
=
torch
.
tensor
([[
1
,
2
]],
dtype
=
torch
.
int32
,
device
=
device
)
logit_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int32
,
device
=
device
)
target_lengths
=
torch
.
tensor
([
2
],
dtype
=
torch
.
int32
,
device
=
device
)
blank
=
-
1
ref_costs
=
np
.
array
([
5.09566688538
],
dtype
=
dtype
)
ref_gradients
=
np
.
array
(
ref_costs
=
torch
.
tensor
([
5.09566688538
],
dtype
=
dtype
)
ref_gradients
=
torch
.
tensor
(
[
0.17703132
,
-
0.39992708
,
...
...
@@ -193,10 +198,9 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
return
data
,
ref_costs
,
ref_gradients
def
get_
numpy_data_
B2_T4_U3_D3
(
dtype
=
np
.
float32
):
def
get_B2_T4_U3_D3
_data
(
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cpu"
)
):
# Test from D21322854
logits
=
np
.
array
(
logits
=
torch
.
tensor
(
[
0.065357
,
0.787530
,
...
...
@@ -272,17 +276,23 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
0.358021
,
],
dtype
=
dtype
,
device
=
device
,
).
reshape
(
2
,
4
,
3
,
3
)
logits
.
requires_grad_
(
True
)
targets
=
np
.
array
([[
1
,
2
],
[
1
,
1
]],
dtype
=
np
.
int32
)
logit_lengths
=
np
.
array
([
4
,
4
],
dtype
=
np
.
int32
)
target_lengths
=
np
.
array
([
2
,
2
],
dtype
=
np
.
int32
)
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
targets
=
torch
.
tensor
([[
1
,
2
],
[
1
,
1
]],
dtype
=
torch
.
int32
,
device
=
device
)
logit_lengths
=
torch
.
tensor
([
4
,
4
],
dtype
=
torch
.
int32
,
device
=
device
)
target_lengths
=
torch
.
tensor
([
2
,
2
],
dtype
=
torch
.
int32
,
device
=
device
)
blank
=
0
ref_costs
=
np
.
array
([
4.2806528590890736
,
3.9384369822503591
],
dtype
=
dtype
)
ref_costs
=
torch
.
tensor
([
4.2806528590890736
,
3.9384369822503591
],
dtype
=
dtype
)
ref_gradients
=
np
.
array
(
ref_gradients
=
torch
.
tensor
(
[
-
0.186844
,
-
0.062555
,
...
...
@@ -371,30 +381,45 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
return
data
,
ref_costs
,
ref_gradients
def
get_numpy_random_data
(
max_B
=
8
,
max_T
=
128
,
max_U
=
32
,
max_D
=
40
,
blank
=-
1
,
dtype
=
np
.
float32
,
seed
=
None
def
get_random_data
(
max_B
=
8
,
max_T
=
128
,
max_U
=
32
,
max_D
=
40
,
blank
=-
1
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cpu"
),
seed
=
None
,
):
if
seed
is
not
None
:
np
.
random
.
seed
(
seed
=
seed
)
torch
.
manual_
seed
(
seed
=
seed
)
if
blank
!=
-
1
:
raise
ValueError
(
"blank != -1 is not supported yet."
)
B
=
np
.
random
.
randint
(
low
=
1
,
high
=
max_B
)
T
=
np
.
random
.
randint
(
low
=
5
,
high
=
max_T
)
U
=
np
.
random
.
randint
(
low
=
5
,
high
=
max_U
)
D
=
np
.
random
.
randint
(
low
=
2
,
high
=
max_D
)
logit_lengths
=
np
.
random
.
randint
(
low
=
5
,
high
=
T
+
1
,
size
=
(
B
,),
dtype
=
np
.
int32
)
target_lengths
=
np
.
random
.
randint
(
low
=
5
,
high
=
U
+
1
,
size
=
(
B
,),
dtype
=
np
.
int32
)
max_src_length
=
np
.
max
(
logit_lengths
)
max_tgt_length
=
np
.
max
(
target_lengths
)
targets
=
np
.
random
.
randint
(
low
=
0
,
high
=
D
-
1
,
size
=
(
B
,
max_tgt_length
),
dtype
=
np
.
int32
random
.
seed
(
0
)
B
=
random
.
randint
(
1
,
max_B
-
1
)
T
=
random
.
randint
(
5
,
max_T
-
1
)
U
=
random
.
randint
(
5
,
max_U
-
1
)
D
=
random
.
randint
(
2
,
max_D
-
1
)
logit_lengths
=
torch
.
randint
(
low
=
5
,
high
=
T
+
1
,
size
=
(
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
target_lengths
=
torch
.
randint
(
low
=
5
,
high
=
U
+
1
,
size
=
(
B
,),
dtype
=
torch
.
int32
,
device
=
device
)
max_src_length
=
torch
.
max
(
logit_lengths
)
max_tgt_length
=
torch
.
max
(
target_lengths
)
targets
=
torch
.
randint
(
low
=
0
,
high
=
D
-
1
,
size
=
(
B
,
max_tgt_length
),
dtype
=
torch
.
int32
,
device
=
device
)
logits
=
np
.
random
.
random_sample
(
size
=
(
B
,
max_src_length
,
max_tgt_length
+
1
,
D
)
).
astype
(
dtype
=
dtype
)
logits
=
torch
.
rand
(
size
=
(
B
,
max_src_length
,
max_tgt_length
+
1
,
D
),
dtype
=
dtype
,
device
=
device
,
).
requires_grad_
(
True
)
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
return
{
"logits"
:
logits
,
...
...
@@ -405,44 +430,6 @@ def get_numpy_random_data(
}
def
numpy_to_torch
(
data
,
device
,
requires_grad
=
True
):
logits
=
torch
.
from_numpy
(
data
[
"logits"
]).
to
(
device
=
device
)
targets
=
torch
.
from_numpy
(
data
[
"targets"
]).
to
(
device
=
device
)
logit_lengths
=
torch
.
from_numpy
(
data
[
"logit_lengths"
]).
to
(
device
=
device
)
target_lengths
=
torch
.
from_numpy
(
data
[
"target_lengths"
]).
to
(
device
=
device
)
if
"nbest_wers"
in
data
:
data
[
"nbest_wers"
]
=
torch
.
from_numpy
(
data
[
"nbest_wers"
]).
to
(
device
=
device
)
if
"nbest_scores"
in
data
:
data
[
"nbest_scores"
]
=
torch
.
from_numpy
(
data
[
"nbest_scores"
]).
to
(
device
=
device
)
logits
=
torch
.
autograd
.
Variable
(
logits
,
requires_grad
=
requires_grad
)
logit_lengths
=
torch
.
autograd
.
Variable
(
logit_lengths
)
target_lengths
=
torch
.
autograd
.
Variable
(
target_lengths
)
targets
=
torch
.
autograd
.
Variable
(
targets
)
if
device
==
torch
.
device
(
"cpu"
):
logits
=
logits
.
cpu
()
elif
device
==
torch
.
device
(
"cuda"
):
logits
=
logits
.
cuda
()
else
:
raise
ValueError
(
"unrecognized device = {}"
.
format
(
device
))
def
grad_hook
(
grad
):
logits
.
saved_grad
=
grad
.
clone
()
logits
.
register_hook
(
grad_hook
)
data
[
"logits"
]
=
logits
data
[
"logit_lengths"
]
=
logit_lengths
data
[
"target_lengths"
]
=
target_lengths
data
[
"targets"
]
=
targets
return
data
def
skipIfNoRNNT
(
test_item
):
try
:
torch
.
ops
.
torchaudio
.
rnnt_loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment