OpenDAS / Torchaudio / Commits

Commit 0ea6d10d (unverified)
Authored Jul 15, 2021 by Caroline Chen; committed by GitHub on Jul 15, 2021

Refactor RNNT Loss Unit Tests (#1630)

Parent: 56ab0368
Showing 3 changed files with 120 additions and 152 deletions (+120 -152):

  test/torchaudio_unittest/rnnt/autograd_impl.py   +11 -15
  test/torchaudio_unittest/rnnt/rnnt_loss_impl.py  +23 -38
  test/torchaudio_unittest/rnnt/utils.py           +86 -99
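At a high level, the refactor drops the intermediate numpy representation: the data factories in utils.py now build torch tensors directly (taking dtype and device keyword arguments), so the tests no longer need a numpy_to_torch conversion step. A minimal sketch of the before/after flow, using hypothetical old_style_data / new_style_data helpers rather than the real test utilities:

    import numpy as np
    import torch


    def old_style_data(dtype=np.float32):
        # Old flow: factories returned numpy arrays; tests converted them afterwards.
        return {"logits": np.random.rand(1, 2, 3, 5).astype(dtype)}


    def new_style_data(dtype=torch.float32, device=torch.device("cpu")):
        # New flow: factories return torch tensors on the requested dtype/device.
        return {"logits": torch.rand(1, 2, 3, 5, dtype=dtype, device=device)}


    old = torch.from_numpy(old_style_data()["logits"])  # extra conversion step
    new = new_style_data()["logits"]                     # already a torch tensor
    assert old.shape == new.shape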
test/torchaudio_unittest/rnnt/autograd_impl.py (View file @ 0ea6d10d)

@@ -8,10 +8,9 @@ from torchaudio_unittest.common_utils import (
 from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
 from parameterized import parameterized
 from .utils import (
-    numpy_to_torch,
     get_B1_T10_U3_D4_data,
-    get_numpy_data_B2_T4_U3_D3,
-    get_numpy_data_B1_T2_U3_D5
+    get_B2_T4_U3_D3_data,
+    get_B1_T2_U3_D5_data
 )
 from .numpy_transducer import NumpyTransducerLoss

@@ -19,12 +18,9 @@ from .numpy_transducer import NumpyTransducerLoss
 class Autograd(TestBaseMixin):
     @staticmethod
     def get_data(data_func, device):
-        data_np = data_func()
-        if type(data_np) == tuple:
-            data_np = data_np[0]
-        data = numpy_to_torch(
-            data=data_np, device=device, requires_grad=True
-        )
+        data = data_func()
+        if type(data) == tuple:
+            data = data[0]
         return data

     def assert_grad(

@@ -46,8 +42,8 @@ class Autograd(TestBaseMixin):
     @parameterized.expand([
         (get_B1_T10_U3_D4_data, ),
-        (get_numpy_data_B2_T4_U3_D3, ),
-        (get_numpy_data_B1_T2_U3_D5, ),
+        (get_B2_T4_U3_D3_data, ),
+        (get_B1_T2_U3_D5_data, ),
     ])
     def test_RNNTLoss_gradcheck(self, data_func):
         data = self.get_data(data_func, self.device)

@@ -63,8 +59,8 @@ class Autograd(TestBaseMixin):
     @parameterized.expand([
         (get_B1_T10_U3_D4_data, ),
-        (get_numpy_data_B2_T4_U3_D3, ),
-        (get_numpy_data_B1_T2_U3_D5, ),
+        (get_B2_T4_U3_D3_data, ),
+        (get_B1_T2_U3_D5_data, ),
     ])
     def test_rnnt_loss_gradcheck(self, data_func):
         data = self.get_data(data_func, self.device)

@@ -83,8 +79,8 @@ class Autograd(TestBaseMixin):
     @parameterized.expand([
         (get_B1_T10_U3_D4_data, ),
-        (get_numpy_data_B2_T4_U3_D3, ),
-        (get_numpy_data_B1_T2_U3_D5, ),
+        (get_B2_T4_U3_D3_data, ),
+        (get_B1_T2_U3_D5_data, ),
     ])
     def test_np_transducer_gradcheck(self, data_func):
         data = self.get_data(data_func, self.device)
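The three gradcheck tests above only change which factory names they expand over; the flow itself stays the same. A standalone sketch of that flow under stated assumptions: make_data is a stand-in for the refactored factories (not torchaudio code), the loss is a placeholder differentiable reduction rather than rnnt_loss, and double precision is used because torch.autograd.gradcheck is most reliable there:

    import torch


    def make_data(dtype=torch.float64, device=torch.device("cpu")):
        # Stand-in factory: returns (data, ref_costs, ref_gradients) like the real helpers.
        data = {"logits": torch.rand(1, 2, 3, 5, dtype=dtype, device=device)}
        return data, None, None


    def get_data(data_func, device):
        # Mirrors Autograd.get_data after the refactor: unwrap tuple-returning factories.
        data = data_func()
        if type(data) == tuple:
            data = data[0]
        return data


    def placeholder_loss(logits):
        # Any smooth reduction works for illustrating gradcheck; rnnt_loss is the real target.
        return logits.logsumexp(dim=-1).sum()


    data = get_data(make_data, torch.device("cpu"))
    logits = data["logits"].requires_grad_(True)
    assert torch.autograd.gradcheck(placeholder_loss, (logits,), eps=1e-6, atol=1e-4)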
test/torchaudio_unittest/rnnt/rnnt_loss_impl.py (View file @ 0ea6d10d)

-import numpy as np
+import torch

 from torchaudio.prototype.rnnt_loss import RNNTLoss

 from .utils import (
     compute_with_numpy_transducer,
     compute_with_pytorch_transducer,
+    get_basic_data,
     get_B1_T10_U3_D4_data,
-    get_data_basic,
-    get_numpy_data_B1_T2_U3_D5,
-    get_numpy_data_B2_T4_U3_D3,
-    get_numpy_random_data,
-    numpy_to_torch,
+    get_B1_T2_U3_D5_data,
+    get_B2_T4_U3_D3_data,
+    get_random_data,
 )

@@ -23,42 +22,30 @@ class RNNTLossTest:
         costs, gradients = compute_with_pytorch_transducer(
             data=data, reuse_logits_for_grads=reuse_logits_for_grads
         )
-        np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
+        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
         self.assertEqual(logits_shape, gradients.shape)
-        if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
-            for b in range(len(gradients)):
-                T = data["logit_lengths"][b]
-                U = data["target_lengths"][b]
-                for t in range(gradients.shape[1]):
-                    for u in range(gradients.shape[2]):
-                        np.testing.assert_allclose(
-                            gradients[b, t, u],
-                            ref_gradients[b, t, u],
-                            atol=atol,
-                            rtol=rtol,
-                            err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
-                        )
+        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)

     def test_basic_backward(self):
         rnnt_loss = RNNTLoss()
-        logits, targets, logit_lengths, target_lengths = get_data_basic(self.device)
+        logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
         loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
         loss.backward()

     def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
-        data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5(
-            dtype=np.float32
-        )
-        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
+        data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
+            dtype=torch.float32,
+            device=self.device,
+        )
         self._test_costs_and_gradients(
             data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
         )

     def test_costs_and_gradients_B1_T2_U3_D5_fp16(self):
-        data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5(
-            dtype=np.float16
-        )
-        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
+        data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
+            dtype=torch.float16,
+            device=self.device,
+        )
         self._test_costs_and_gradients(
             data=data,
             ref_costs=ref_costs,

@@ -68,19 +55,19 @@ class RNNTLossTest:
         )

     def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
-        data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(
-            dtype=np.float32
-        )
-        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
+        data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
+            dtype=torch.float32,
+            device=self.device,
+        )
         self._test_costs_and_gradients(
             data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
         )

     def test_costs_and_gradients_B2_T4_U3_D3_fp16(self):
-        data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(
-            dtype=np.float16
-        )
-        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
+        data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
+            dtype=torch.float16,
+            device=self.device,
+        )
         self._test_costs_and_gradients(
             data=data,
             ref_costs=ref_costs,

@@ -92,8 +79,7 @@ class RNNTLossTest:
     def test_costs_and_gradients_random_data_with_numpy_fp32(self):
         seed = 777
         for i in range(5):
-            data = get_numpy_random_data(dtype=np.float32, seed=(seed + i))
-            data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
+            data = get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
             ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
             self._test_costs_and_gradients(
                 data=data, ref_costs=ref_costs, ref_gradients=ref_gradients

@@ -103,9 +89,8 @@ class RNNTLossTest:
         for random in [False, True]:
             data = get_B1_T10_U3_D4_data(
                 random=random,
-            )
-            data = numpy_to_torch(
-                data=data, device=self.device, requires_grad=True
-            )
+                dtype=torch.float32,
+                device=self.device,
+            )
             data["fused_log_softmax"] = False
             ref_costs, ref_gradients = compute_with_numpy_transducer(
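The substantive change in this file is the assertion style: the per-element np.testing.assert_allclose loop over (b, t, u) is collapsed into whole-tensor comparisons through self.assertEqual(..., atol=..., rtol=...), a helper of torchaudio's test TestCase. Outside that harness, the same check can be sketched with plain torch.allclose; the tolerances below are illustrative, not the ones the suite uses:

    import torch

    atol, rtol = 1e-6, 1e-2  # illustrative tolerances

    # Reference costs copied from get_B2_T4_U3_D3_data in the diff above.
    ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591])
    costs = ref_costs + 1e-8  # stands in for compute_with_pytorch_transducer output

    ref_gradients = torch.zeros(2, 4, 3, 3)
    gradients = torch.zeros(2, 4, 3, 3)

    # One call per tensor replaces the nested per-(b, t, u) numpy loop.
    assert costs.shape == ref_costs.shape
    assert torch.allclose(costs, ref_costs, atol=atol, rtol=rtol)
    assert gradients.shape == ref_gradients.shape
    assert torch.allclose(gradients, ref_gradients, atol=atol, rtol=rtol)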
test/torchaudio_unittest/rnnt/utils.py (View file @ 0ea6d10d)

 import unittest
+import random

-import numpy as np
 import torch
 from torchaudio.prototype.rnnt_loss import RNNTLoss

@@ -19,10 +18,8 @@ def compute_with_numpy_transducer(data):
     loss = torch.sum(costs)
     loss.backward()
-    costs = costs.cpu().data.numpy()
-    gradients = data["logits"].saved_grad.cpu().data.numpy()
+    costs = costs.cpu()
+    gradients = data["logits"].saved_grad.cpu()
     return costs, gradients

@@ -41,12 +38,12 @@ def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
     loss = torch.sum(costs)
     loss.backward()
-    costs = costs.cpu().data.numpy()
-    gradients = data["logits"].saved_grad.cpu().data.numpy()
+    costs = costs.cpu()
+    gradients = data["logits"].saved_grad.cpu()
     return costs, gradients


-def get_data_basic(device):
+def get_basic_data(device):
     # Example provided
     # in 6f73a2513dc784c59eec153a45f40bc528355b18
     # of https://github.com/HawkAaron/warp-transducer

@@ -66,16 +63,12 @@ def get_data_basic(device):
             ],
         ]
         ],
-        dtype=torch.float,
+        dtype=torch.float32,
+        device=device,
     )
-    targets = torch.tensor([[1, 2]], dtype=torch.int)
-    logit_lengths = torch.tensor([2], dtype=torch.int)
-    target_lengths = torch.tensor([2], dtype=torch.int)
-    logits = logits.to(device=device)
-    targets = targets.to(device=device)
-    logit_lengths = logit_lengths.to(device=device)
-    target_lengths = target_lengths.to(device=device)
+    targets = torch.tensor([[1, 2]], dtype=torch.int, device=device)
+    logit_lengths = torch.tensor([2], dtype=torch.int, device=device)
+    target_lengths = torch.tensor([2], dtype=torch.int, device=device)

     logits.requires_grad_(True)

@@ -84,27 +77,32 @@ def get_data_basic(device):
 def get_B1_T10_U3_D4_data(
     random=False,
-    dtype=np.float32,
-    nan=False,
+    dtype=torch.float32,
+    device=torch.device("cpu"),
 ):
     B, T, U, D = 2, 10, 3, 4
-    data = {}
-    data["logits"] = np.random.rand(B, T, U, D).astype(dtype)
+    logits = torch.rand(B, T, U, D, dtype=dtype, device=device)
     if not random:
-        data["logits"].fill(0.1)
-    if nan:
-        for i in range(B):
-            data["logits"][i][0][0][0] = np.nan
-    data["logit_lengths"] = np.array([10, 10], dtype=np.int32)
-    data["target_lengths"] = np.array([2, 2], dtype=np.int32)
-    data["targets"] = np.array([[1, 2], [1, 2]], dtype=np.int32)
+        logits.fill_(0.1)
+    logits.requires_grad_(True)
+
+    def grad_hook(grad):
+        logits.saved_grad = grad.clone()
+    logits.register_hook(grad_hook)
+
+    data = {}
+    data["logits"] = logits
+    data["logit_lengths"] = torch.tensor([10, 10], dtype=torch.int32, device=device)
+    data["target_lengths"] = torch.tensor([2, 2], dtype=torch.int32, device=device)
+    data["targets"] = torch.tensor([[1, 2], [1, 2]], dtype=torch.int32, device=device)
     data["blank"] = 0

     return data


-def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
-    logits = np.array(
+def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")):
+    logits = torch.tensor(
         [
             0.1,
             0.6,

@@ -138,15 +136,22 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
             0.1,
         ],
         dtype=dtype,
+        device=device,
     ).reshape(1, 2, 3, 5)
-    targets = np.array([[1, 2]], dtype=np.int32)
-    logit_lengths = np.array([2], dtype=np.int32)
-    target_lengths = np.array([2], dtype=np.int32)
+    logits.requires_grad_(True)
+
+    def grad_hook(grad):
+        logits.saved_grad = grad.clone()
+    logits.register_hook(grad_hook)
+
+    targets = torch.tensor([[1, 2]], dtype=torch.int32, device=device)
+    logit_lengths = torch.tensor([2], dtype=torch.int32, device=device)
+    target_lengths = torch.tensor([2], dtype=torch.int32, device=device)
     blank = -1

-    ref_costs = np.array([5.09566688538], dtype=dtype)
-    ref_gradients = np.array(
+    ref_costs = torch.tensor([5.09566688538], dtype=dtype)
+    ref_gradients = torch.tensor(
         [
             0.17703132,
             -0.39992708,

@@ -193,10 +198,9 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
     return data, ref_costs, ref_gradients


-def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
+def get_B2_T4_U3_D3_data(dtype=torch.float32, device=torch.device("cpu")):
     # Test from D21322854
-    logits = np.array(
+    logits = torch.tensor(
         [
             0.065357,
             0.787530,

@@ -272,17 +276,23 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
             0.358021,
         ],
         dtype=dtype,
+        device=device,
     ).reshape(2, 4, 3, 3)
-    targets = np.array([[1, 2], [1, 1]], dtype=np.int32)
-    logit_lengths = np.array([4, 4], dtype=np.int32)
-    target_lengths = np.array([2, 2], dtype=np.int32)
+    logits.requires_grad_(True)
+
+    def grad_hook(grad):
+        logits.saved_grad = grad.clone()
+    logits.register_hook(grad_hook)
+
+    targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32, device=device)
+    logit_lengths = torch.tensor([4, 4], dtype=torch.int32, device=device)
+    target_lengths = torch.tensor([2, 2], dtype=torch.int32, device=device)
     blank = 0

-    ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype)
-    ref_gradients = np.array(
+    ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype)
+    ref_gradients = torch.tensor(
         [
             -0.186844,
             -0.062555,

@@ -371,30 +381,45 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
     return data, ref_costs, ref_gradients


-def get_numpy_random_data(
-    max_B=8, max_T=128, max_U=32, max_D=40, blank=-1, dtype=np.float32, seed=None
+def get_random_data(
+    max_B=8,
+    max_T=128,
+    max_U=32,
+    max_D=40,
+    blank=-1,
+    dtype=torch.float32,
+    device=torch.device("cpu"),
+    seed=None,
 ):
     if seed is not None:
-        np.random.seed(seed=seed)
+        torch.manual_seed(seed=seed)

     if blank != -1:
         raise ValueError("blank != -1 is not supported yet.")

-    B = np.random.randint(low=1, high=max_B)
-    T = np.random.randint(low=5, high=max_T)
-    U = np.random.randint(low=5, high=max_U)
-    D = np.random.randint(low=2, high=max_D)
+    random.seed(0)
+    B = random.randint(1, max_B - 1)
+    T = random.randint(5, max_T - 1)
+    U = random.randint(5, max_U - 1)
+    D = random.randint(2, max_D - 1)

-    logit_lengths = np.random.randint(low=5, high=T + 1, size=(B,), dtype=np.int32)
-    target_lengths = np.random.randint(low=5, high=U + 1, size=(B,), dtype=np.int32)
-    max_src_length = np.max(logit_lengths)
-    max_tgt_length = np.max(target_lengths)
-    targets = np.random.randint(
-        low=0, high=D - 1, size=(B, max_tgt_length), dtype=np.int32
-    )
-    logits = np.random.random_sample(
-        size=(B, max_src_length, max_tgt_length + 1, D)
-    ).astype(dtype=dtype)
+    logit_lengths = torch.randint(low=5, high=T + 1, size=(B,), dtype=torch.int32, device=device)
+    target_lengths = torch.randint(low=5, high=U + 1, size=(B,), dtype=torch.int32, device=device)
+    max_src_length = torch.max(logit_lengths)
+    max_tgt_length = torch.max(target_lengths)
+    targets = torch.randint(
+        low=0, high=D - 1, size=(B, max_tgt_length), dtype=torch.int32, device=device
+    )
+    logits = torch.rand(
+        size=(B, max_src_length, max_tgt_length + 1, D),
+        dtype=dtype,
+        device=device,
+    ).requires_grad_(True)
+
+    def grad_hook(grad):
+        logits.saved_grad = grad.clone()
+    logits.register_hook(grad_hook)

     return {
         "logits": logits,

@@ -405,44 +430,6 @@ def get_numpy_random_data(
     }


-def numpy_to_torch(data, device, requires_grad=True):
-    logits = torch.from_numpy(data["logits"]).to(device=device)
-    targets = torch.from_numpy(data["targets"]).to(device=device)
-    logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
-    target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)
-
-    if "nbest_wers" in data:
-        data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
-    if "nbest_scores" in data:
-        data["nbest_scores"] = torch.from_numpy(data["nbest_scores"]).to(device=device)
-
-    logits = torch.autograd.Variable(logits, requires_grad=requires_grad)
-    logit_lengths = torch.autograd.Variable(logit_lengths)
-    target_lengths = torch.autograd.Variable(target_lengths)
-    targets = torch.autograd.Variable(targets)
-
-    if device == torch.device("cpu"):
-        logits = logits.cpu()
-    elif device == torch.device("cuda"):
-        logits = logits.cuda()
-    else:
-        raise ValueError("unrecognized device = {}".format(device))
-
-    def grad_hook(grad):
-        logits.saved_grad = grad.clone()
-    logits.register_hook(grad_hook)
-
-    data["logits"] = logits
-    data["logit_lengths"] = logit_lengths
-    data["target_lengths"] = target_lengths
-    data["targets"] = targets
-
-    return data
-
-
 def skipIfNoRNNT(test_item):
     try:
         torch.ops.torchaudio.rnnt_loss
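Every torch-native factory above installs the same gradient-capture hook that numpy_to_torch used to set up: a clone of the incoming gradient is stashed as logits.saved_grad, which is what compute_with_numpy_transducer and compute_with_pytorch_transducer read after loss.backward(). A self-contained sketch of just that pattern:

    import torch

    logits = torch.rand(2, 10, 3, 4, dtype=torch.float32)
    logits.requires_grad_(True)


    def grad_hook(grad):
        # Keep a copy of the gradient flowing into `logits` so it can be
        # inspected after backward() has run.
        logits.saved_grad = grad.clone()


    logits.register_hook(grad_hook)

    loss = torch.sum(logits)
    loss.backward()

    # For a leaf tensor the hook's copy matches .grad exactly.
    assert torch.equal(logits.saved_grad, logits.grad)
    print(logits.saved_grad.shape)  # torch.Size([2, 10, 3, 4])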