OpenDAS / FAST-RNNT / Commits

Commit b0ed23ef
Authored Oct 30, 2022 by pkufool

Add constrained rnnt
Parent: dc35168d

Showing 2 changed files with 292 additions and 152 deletions (+292 -152)
fast_rnnt/python/fast_rnnt/rnnt_loss.py     +254 -122
fast_rnnt/python/tests/rnnt_loss_test.py    +38  -30
fast_rnnt/python/fast_rnnt/rnnt_loss.py

This diff is collapsed.
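The test changes below replace the old boolean `modified` flag with a string `rnnt_type` argument ("regular", "modified", "constrained") when calling the loss functions defined in the collapsed rnnt_loss.py above. A minimal usage sketch of the new argument, assuming rnnt_loss_simple keeps the keyword names visible in the test diff; the tensor sizes and the choice of 0 as the blank/termination symbol are illustrative, not taken from the commit:

import torch
import fast_rnnt

# Illustrative sizes: batch, frames, symbols, vocabulary.
B, T, S, C = 2, 10, 5, 20
am = torch.randn(B, T, C)          # encoder ("am") output
lm = torch.randn(B, S + 1, C)      # prediction-network ("lm") output
symbols = torch.randint(1, C, (B, S))
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S                 # number of symbols per sequence
boundary[:, 3] = T                 # number of frames per sequence

for rnnt_type in ["regular", "modified", "constrained"]:
    # rnnt_type is the argument added by this commit; it replaces the
    # old `modified=True/False` flag used in the previous test code.
    loss = fast_rnnt.rnnt_loss_simple(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=0,      # assumed blank id for this sketch
        boundary=boundary,
        rnnt_type=rnnt_type,
    )
    print(rnnt_type, loss)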
fast_rnnt/python/tests/rnnt_loss_test.py
@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
             assert px.shape == (B, S, T + 1)
             assert py.shape == (B, S + 1, T)
             assert symbols.shape == (B, S)
             m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
             if device == torch.device("cpu"):
                 expected = -m
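Reading the shape checks above: for the "regular" topology px is (B, S, T + 1) and py is (B, S + 1, T), and the loss is the negated mutual information (hence expected = -m). A small sketch with random stand-ins for the px/py log-probs, using only the keyword arguments shown in the diff:

import torch
import fast_rnnt

B, S, T = 2, 3, 4
# Random stand-ins for the per-transition log-probs; in the real test these
# come from the am/lm outputs (px: emit-a-symbol terms, py: blank terms).
px = torch.randn(B, S, T + 1)   # "regular" shape, as asserted above
py = torch.randn(B, S + 1, T)
m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
loss = -m                       # mirrors `expected = -m` in the test
print(loss)                     # one value per sequence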
@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
         boundary_[:, 2] = seq_length
         boundary_[:, 3] = frames
-        for modified in [True, False]:
+        for rnnt_type in ["regular", "modified", "constrained"]:
             for device in self.devices:
                 # lm: [B][S+1][C]
                 lm = lm_.to(device)
@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=termination_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
-                assert px.shape == (B, S, T) if modified else (B, S, T + 1)
+                assert (
+                    px.shape == (B, S, T)
+                    if rnnt_type != "regular"
+                    else (B, S, T + 1)
+                )
                 assert py.shape == (B, S + 1, T)
                 assert symbols.shape == (B, S)
                 m = fast_rnnt.mutual_information_recursion(
@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=termination_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))
@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
                     lm_only_scale=0.0,
                     am_only_scale=0.0,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))
@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=termination_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 assert torch.allclose(m, expected.to(device))

                 # compare with torchaudio rnnt_loss
-                if self.has_torch_rnnt_loss and not modified:
+                if self.has_torch_rnnt_loss and rnnt_type == "regular":
                     import torchaudio.functional

                     m = torchaudio.functional.rnnt_loss(
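The hunk above is cut off at the torchaudio call; note that the comparison now runs only for rnnt_type == "regular", since torchaudio's rnnt_loss implements the standard (unconstrained) RNN-T recursion. A self-contained sketch of that kind of cross-check on random joint-network logits; the fast_rnnt keyword names follow the test diff, torchaudio's follow its documented signature, and reduction="none" on fast_rnnt.rnnt_loss is an assumption here:

import torch
import fast_rnnt
import torchaudio.functional

B, T, S, C = 2, 10, 5, 20
logits = torch.randn(B, T, S + 1, C)     # joint-network output
symbols = torch.randint(1, C, (B, S))    # 0 is used as the blank symbol here
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S
boundary[:, 3] = T

fast_loss = fast_rnnt.rnnt_loss(
    logits=logits,
    symbols=symbols,
    termination_symbol=0,
    boundary=boundary,
    rnnt_type="regular",
    reduction="none",       # assumed to be supported, as it is for rnnt_loss_simple
)
torch_loss = torchaudio.functional.rnnt_loss(
    logits,
    symbols.int(),
    logit_lengths=boundary[:, 3].int(),
    target_lengths=boundary[:, 2].int(),
    blank=0,
    reduction="none",
)
print(fast_loss, torch_loss)   # the test expects these to be allclose for "regular" rnnt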
@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
                         symbols=symbols,
                         termination_symbol=termination_symbol,
                         boundary=boundary,
-                        modified=modified,
+                        rnnt_type=rnnt_type,
                     )
                     assert torch.allclose(m, expected.to(device))
@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
                         symbols=symbols,
                         termination_symbol=termination_symbol,
                         boundary=boundary,
-                        modified=modified,
+                        rnnt_type=rnnt_type,
                     )
                     assert torch.allclose(m, expected.to(device))
@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
                         lm_only_scale=0.0,
                         am_only_scale=0.0,
                         boundary=boundary,
-                        modified=modified,
+                        rnnt_type=rnnt_type,
                     )
                     assert torch.allclose(m, expected.to(device))
@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
                     torch_grad = torch.autograd.grad(torch_loss, logits2)
                     torch_grad = torch_grad[0]
                     assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)
                     assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)

     def test_rnnt_loss_smoothed(self):
         B = 1
@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
         boundary_[:, 2] = seq_length
         boundary_[:, 3] = frames
-        for modified in [True, False]:
+        for rnnt_type in ["regular", "modified", "constrained"]:
             for device in self.devices:
                 # normal rnnt
                 am = am_.to(device)
@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                 )
                 print(
-                    f"Unpruned rnnt loss with modified {modified} : {fast_loss}"
+                    f"Unpruned rnnt loss with {rnnt_type} rnnt : {fast_loss}"
                 )

                 # pruning
                 simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                     return_grad=True,
                     reduction="none",
                 )
@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
                         s_range=r,
                     )
                     # (B, T, r, C)
                     pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
                     logits = pruned_am + pruned_lm
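The lines above are the heart of the pruned pipeline: gradients of the simple loss define per-frame pruning bounds of width s_range (= r), do_rnnt_pruning gathers the am/lm frames inside those bounds, and the pruned loss runs on the resulting (B, T, r, C) logits. A condensed sketch of the whole flow with the new rnnt_type argument; the bounds are assumed to come from fast_rnnt.get_rnnt_prune_ranges with k2-style keyword names (px_grad, py_grad, boundary, s_range), which this hunk does not show:

import torch
import fast_rnnt

B, T, S, C = 2, 10, 5, 20
terminal_symbol = 0
am = torch.randn(B, T, C)
lm = torch.randn(B, S + 1, C)
symbols = torch.randint(1, C, (B, S))
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S
boundary[:, 3] = T

for rnnt_type in ["regular", "modified", "constrained"]:
    # 1) trivial-joiner loss, also returning gradients w.r.t. px/py
    simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=terminal_symbol,
        boundary=boundary,
        rnnt_type=rnnt_type,
        return_grad=True,
        reduction="none",
    )
    # 2) pruning bounds of width r derived from those gradients
    #    (get_rnnt_prune_ranges and its keyword names are assumed here)
    r = 3
    ranges = fast_rnnt.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=r
    )
    # 3) gather pruned am/lm, build (B, T, r, C) logits, run the pruned loss
    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
    logits = pruned_am + pruned_lm
    pruned_loss = fast_rnnt.rnnt_loss_pruned(
        logits=logits,
        symbols=symbols,
        ranges=ranges,
        termination_symbol=terminal_symbol,
        boundary=boundary,
        rnnt_type=rnnt_type,
        reduction="none",
    )
    print(rnnt_type, pruned_loss)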
@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
                         ranges=ranges,
                         termination_symbol=terminal_symbol,
                         boundary=boundary,
-                        modified=modified,
+                        rnnt_type=rnnt_type,
                         reduction="none",
                     )
                     print(f"Pruning loss with range {r} : {pruned_loss}")

         # Test the sequences that only have small number of symbols,
         # at this circumstance, the s_range would be greater than S, which will
         # raise errors (like, nan or inf loss) in our previous versions.
@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
         print(f"B = {B}, T = {T}, S = {S}, C = {C}")
-        for modified in [True, False]:
+        for rnnt_type in ["regular", "modified", "constrained"]:
             for device in self.devices:
                 # normal rnnt
                 am = am_.to(device)
@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                     reduction="none",
                 )
                 print(
-                    f"Unpruned rnnt loss with modified {modified} : {loss}"
+                    f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}"
                 )

                 # pruning
                 simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
                     symbols=symbols,
                     termination_symbol=terminal_symbol,
                     boundary=boundary,
-                    modified=modified,
+                    rnnt_type=rnnt_type,
                     return_grad=True,
                     reduction="none",
                 )

                 S0 = 2
-                if modified:
+                if rnnt_type == "regular":
                     S0 = 1
                 for r in range(S0, S + 2):
@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
                         ranges=ranges,
                         termination_symbol=terminal_symbol,
                         boundary=boundary,
-                        modified=modified,
+                        rnnt_type=rnnt_type,
                         reduction="none",
                     )
                     print(f"Pruned loss with range {r} : {pruned_loss}")


 if __name__ == "__main__":
     unittest.main()
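Since the file ends by calling unittest.main(), the updated tests can be run directly, e.g. with python3 fast_rnnt/python/tests/rnnt_loss_test.py, provided fast_rnnt is importable (and torchaudio is installed if the comparison branch is to be exercised).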