Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
15a3d1cd
Commit
15a3d1cd
authored
Jul 11, 2022
by
pkufool
Browse files
Fix pruning bounds
parent
134c1bcc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
177 additions
and
60 deletions
+177
-60
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+47
-25
fast_rnnt/python/tests/rnnt_loss_test.py
fast_rnnt/python/tests/rnnt_loss_test.py
+130
-35
No files found.
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
15a3d1cd
...
@@ -134,6 +134,8 @@ def get_rnnt_logprobs(
...
@@ -134,6 +134,8 @@ def get_rnnt_logprobs(
(
B
,
T
,
C
)
=
am
.
shape
(
B
,
T
,
C
)
=
am
.
shape
S
=
lm
.
shape
[
1
]
-
1
S
=
lm
.
shape
[
1
]
-
1
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
)
assert
S
>=
1
assert
T
>=
S
# subtracting am_max and lm_max is to ensure the probs are in a good range
# subtracting am_max and lm_max is to ensure the probs are in a good range
# to do exp() without causing underflow or overflow.
# to do exp() without causing underflow or overflow.
...
@@ -331,6 +333,8 @@ def get_rnnt_logprobs_joint(
...
@@ -331,6 +333,8 @@ def get_rnnt_logprobs_joint(
(
B
,
T
,
S1
,
C
)
=
logits
.
shape
(
B
,
T
,
S1
,
C
)
=
logits
.
shape
S
=
S1
-
1
S
=
S1
-
1
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
)
assert
S
>=
1
assert
T
>=
S
normalizers
=
torch
.
logsumexp
(
logits
,
dim
=
3
)
normalizers
=
torch
.
logsumexp
(
logits
,
dim
=
3
)
normalizers
=
normalizers
.
permute
((
0
,
2
,
1
))
normalizers
=
normalizers
.
permute
((
0
,
2
,
1
))
...
@@ -478,7 +482,9 @@ def _adjust_pruning_lower_bound(
...
@@ -478,7 +482,9 @@ def _adjust_pruning_lower_bound(
)
)
return
s_begin
return
s_begin
# To get more insight of how we calculate pruning bounds, please read
# chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
# (https://arxiv.org/pdf/2206.13236.pdf)
def
get_rnnt_prune_ranges
(
def
get_rnnt_prune_ranges
(
px_grad
:
torch
.
Tensor
,
px_grad
:
torch
.
Tensor
,
py_grad
:
torch
.
Tensor
,
py_grad
:
torch
.
Tensor
,
...
@@ -505,8 +511,8 @@ def get_rnnt_prune_ranges(
...
@@ -505,8 +511,8 @@ def get_rnnt_prune_ranges(
of symbols given a particular frame.
of symbols given a particular frame.
Note:
Note:
For the generated tensor ranges
, ranges[:, 0] is a monotonic increasing
For the generated tensor ranges
(assuming batch size is 1), ranges[:, 0]
tensor from 0 to `len(symbols)` and it satisfies
is a monotonic increasing
tensor from 0 to `len(symbols)` and it satisfies
`ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
`ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
symbols.
symbols.
...
@@ -529,33 +535,43 @@ def get_rnnt_prune_ranges(
...
@@ -529,33 +535,43 @@ def get_rnnt_prune_ranges(
(
B
,
S
,
T1
)
=
px_grad
.
shape
(
B
,
S
,
T1
)
=
px_grad
.
shape
T
=
py_grad
.
shape
[
-
1
]
T
=
py_grad
.
shape
[
-
1
]
assert
T1
in
[
T
,
T
+
1
]
assert
T1
in
[
T
,
T
+
1
]
S1
=
S
+
1
assert
py_grad
.
shape
==
(
B
,
S
+
1
,
T
)
assert
py_grad
.
shape
==
(
B
,
S
+
1
,
T
)
assert
boundary
.
shape
==
(
B
,
4
)
assert
boundary
.
shape
==
(
B
,
4
)
assert
s_range
>=
1
assert
S
>=
1
assert
T
>=
S
# s_range > S means we won't prune out any symbols. To make indexing with
# ranges runs normally, s_range should be equal to or less than ``S + 1``.
if
s_range
>
S
:
if
s_range
>
S
:
s_range
=
S
s_range
=
S
+
1
px_pad
=
torch
.
zeros
((
B
,
1
,
T1
),
dtype
=
px_grad
.
dtype
,
device
=
px_grad
.
device
)
if
T1
==
T
:
py_pad
=
torch
.
zeros
(
assert
(
(
B
,
S
+
1
,
1
),
dtype
=
py_grad
.
dtype
,
device
=
py_grad
.
device
s_range
>=
1
)
),
"Pruning range for modified RNN-T should be equal to or greater than 1, or no valid paths could survive pruning."
py_grad_padded
=
py_grad
if
T1
==
T
else
torch
.
cat
((
py_grad
,
py_pad
),
dim
=
2
)
tot_grad
=
(
torch
.
cat
((
px_grad
,
px_pad
),
dim
=
1
)
+
py_grad_padded
)
# (B, S + 1, T1)
tot_grad
=
torch
.
cat
(
else
:
(
assert
(
torch
.
zeros
(
s_range
>=
2
(
B
,
1
,
T1
),
dtype
=
tot_grad
.
dtype
,
device
=
tot_grad
.
device
),
"Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
),
tot_grad
,
blk_grad
=
torch
.
as_strided
(
),
py_grad
,
(
B
,
S1
-
s_range
+
1
,
s_range
,
T
),
(
S1
*
T
,
T
,
T
,
1
)
dim
=
1
,
)
)
tot_grad
=
torch
.
cumsum
(
tot_grad
,
dim
=
1
)
# (B, S1 - s_range + 1, T)
diff_grad
=
tot_grad
[:,
s_range
:,
:]
-
tot_grad
[:,
0
:
-
s_range
,
:]
blk_sum_grad
=
torch
.
sum
(
blk_grad
,
axis
=
2
)
s_begin
=
torch
.
argmax
(
diff_grad
,
dim
=
1
)
px_pad
=
torch
.
zeros
((
B
,
1
,
T1
),
dtype
=
px_grad
.
dtype
,
device
=
px_grad
.
device
)
# (B, S1, T)
px_grad_pad
=
torch
.
cat
((
px_pad
,
px_grad
),
dim
=
1
)
# (B, S1 - s_range + 1, T)
final_grad
=
blk_sum_grad
-
px_grad_pad
[:,
:
S1
-
s_range
+
1
,
:
T
]
# (B, T)
s_begin
=
torch
.
argmax
(
final_grad
,
axis
=
1
)
s_begin
=
s_begin
[:,
:
T
]
s_begin
=
s_begin
[:,
:
T
]
# Handle the values of s_begin in padding positions.
# Handle the values of s_begin in padding positions.
...
@@ -568,7 +584,7 @@ def get_rnnt_prune_ranges(
...
@@ -568,7 +584,7 @@ def get_rnnt_prune_ranges(
s_begin_padding
=
boundary
[:,
2
].
reshape
(
B
,
1
)
-
s_range
+
1
s_begin_padding
=
boundary
[:,
2
].
reshape
(
B
,
1
)
-
s_range
+
1
# handle the cases when `len(symbols) < s_range`
# handle the cases when `len(symbols) < s_range`
s_begin_padding
=
torch
.
where
(
s_begin_padding
>=
0
,
s_begin_padding
,
0
)
s_begin_padding
=
torch
.
clamp
(
s_begin_padding
,
min
=
0
)
s_begin
=
torch
.
where
(
mask
,
s_begin
,
s_begin_padding
)
s_begin
=
torch
.
where
(
mask
,
s_begin
,
s_begin_padding
)
...
@@ -578,9 +594,11 @@ def get_rnnt_prune_ranges(
...
@@ -578,9 +594,11 @@ def get_rnnt_prune_ranges(
# the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
# the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
# it only emits one symbol per frame.
# it only emits one symbol per frame.
s_begin
=
_adjust_pruning_lower_bound
(
s_begin
,
2
if
T1
==
T
else
s_range
)
s_begin
=
_adjust_pruning_lower_bound
(
s_begin
,
2
if
T1
==
T
else
s_range
)
ranges
=
s_begin
.
reshape
((
B
,
T
,
1
)).
expand
((
B
,
T
,
s_range
))
+
torch
.
arange
(
ranges
=
s_begin
.
reshape
((
B
,
T
,
1
)).
expand
((
B
,
T
,
s_range
))
+
torch
.
arange
(
s_range
,
device
=
px_grad
.
device
s_range
,
device
=
px_grad
.
device
)
)
return
ranges
return
ranges
...
@@ -699,6 +717,8 @@ def get_rnnt_logprobs_pruned(
...
@@ -699,6 +717,8 @@ def get_rnnt_logprobs_pruned(
(
B
,
T
,
s_range
,
C
)
=
logits
.
shape
(
B
,
T
,
s_range
,
C
)
=
logits
.
shape
assert
ranges
.
shape
==
(
B
,
T
,
s_range
)
assert
ranges
.
shape
==
(
B
,
T
,
s_range
)
(
B
,
S
)
=
symbols
.
shape
(
B
,
S
)
=
symbols
.
shape
assert
S
>=
1
assert
T
>=
S
normalizers
=
torch
.
logsumexp
(
logits
,
dim
=
3
)
normalizers
=
torch
.
logsumexp
(
logits
,
dim
=
3
)
...
@@ -955,6 +975,8 @@ def get_rnnt_logprobs_smoothed(
...
@@ -955,6 +975,8 @@ def get_rnnt_logprobs_smoothed(
(
B
,
T
,
C
)
=
am
.
shape
(
B
,
T
,
C
)
=
am
.
shape
S
=
lm
.
shape
[
1
]
-
1
S
=
lm
.
shape
[
1
]
-
1
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
)
assert
S
>=
1
assert
T
>=
S
# Caution: some parts of this code are a little less clear than they could
# Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that
# be due to optimizations. In particular it may not be totally obvious that
...
...
fast_rnnt/python/tests/rnnt_loss_test.py
View file @
15a3d1cd
...
@@ -120,11 +120,11 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -120,11 +120,11 @@ class TestRnntLoss(unittest.TestCase):
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
prob
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
logit
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
# test rnnt_loss
# test rnnt_loss
m
=
fast_rnnt
.
rnnt_loss
(
m
=
fast_rnnt
.
rnnt_loss
(
logits
=
prob
s
,
logits
=
logit
s
,
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
None
,
boundary
=
None
,
...
@@ -137,7 +137,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -137,7 +137,7 @@ class TestRnntLoss(unittest.TestCase):
import
torchaudio.functional
import
torchaudio.functional
m
=
torchaudio
.
functional
.
rnnt_loss
(
m
=
torchaudio
.
functional
.
rnnt_loss
(
logits
=
prob
s
,
logits
=
logit
s
,
targets
=
symbols
.
int
(),
targets
=
symbols
.
int
(),
logit_lengths
=
torch
.
tensor
(
logit_lengths
=
torch
.
tensor
(
[
T
]
*
B
,
dtype
=
torch
.
int32
,
device
=
device
[
T
]
*
B
,
dtype
=
torch
.
int32
,
device
=
device
...
@@ -176,9 +176,9 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -176,9 +176,9 @@ class TestRnntLoss(unittest.TestCase):
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
prob
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
logit
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
m
=
fast_rnnt
.
rnnt_loss
(
m
=
fast_rnnt
.
rnnt_loss
(
logits
=
prob
s
,
logits
=
logit
s
,
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
None
,
boundary
=
None
,
...
@@ -255,9 +255,9 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -255,9 +255,9 @@ class TestRnntLoss(unittest.TestCase):
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
prob
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
logit
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
m
=
fast_rnnt
.
rnnt_loss
(
m
=
fast_rnnt
.
rnnt_loss
(
logits
=
prob
s
,
logits
=
logit
s
,
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
...
@@ -270,7 +270,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -270,7 +270,7 @@ class TestRnntLoss(unittest.TestCase):
import
torchaudio.functional
import
torchaudio.functional
m
=
torchaudio
.
functional
.
rnnt_loss
(
m
=
torchaudio
.
functional
.
rnnt_loss
(
logits
=
prob
s
,
logits
=
logit
s
,
targets
=
symbols
.
int
(),
targets
=
symbols
.
int
(),
logit_lengths
=
boundary
[:,
3
].
int
(),
logit_lengths
=
boundary
[:,
3
].
int
(),
target_lengths
=
boundary
[:,
2
].
int
(),
target_lengths
=
boundary
[:,
2
].
int
(),
...
@@ -292,9 +292,9 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -292,9 +292,9 @@ class TestRnntLoss(unittest.TestCase):
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
prob
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
logit
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
m
=
fast_rnnt
.
rnnt_loss
(
m
=
fast_rnnt
.
rnnt_loss
(
logits
=
prob
s
,
logits
=
logit
s
,
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
...
@@ -345,32 +345,32 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -345,32 +345,32 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols_
.
to
(
device
)
symbols
=
symbols_
.
to
(
device
)
boundary
=
boundary_
.
to
(
device
)
boundary
=
boundary_
.
to
(
device
)
log
prob
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
log
it
s
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
log
prob
s
.
requires_grad_
()
log
it
s
.
requires_grad_
()
k2
_loss
=
fast_rnnt
.
rnnt_loss
(
fast
_loss
=
fast_rnnt
.
rnnt_loss
(
logits
=
log
prob
s
,
logits
=
log
it
s
,
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
)
)
k2
_grad
=
torch
.
autograd
.
grad
(
k2
_loss
,
log
prob
s
)
fast
_grad
=
torch
.
autograd
.
grad
(
fast
_loss
,
log
it
s
)
k2
_grad
=
k2
_grad
[
0
]
fast
_grad
=
fast
_grad
[
0
]
log
prob
s2
=
log
prob
s
.
detach
().
clone
().
float
()
log
it
s2
=
log
it
s
.
detach
().
clone
().
float
()
log
prob
s2
.
requires_grad_
()
log
it
s2
.
requires_grad_
()
torch_loss
=
torchaudio
.
functional
.
rnnt_loss
(
torch_loss
=
torchaudio
.
functional
.
rnnt_loss
(
logits
=
log
prob
s2
,
logits
=
log
it
s2
,
targets
=
symbols
.
int
(),
targets
=
symbols
.
int
(),
logit_lengths
=
boundary
[:,
3
].
int
(),
logit_lengths
=
boundary
[:,
3
].
int
(),
target_lengths
=
boundary
[:,
2
].
int
(),
target_lengths
=
boundary
[:,
2
].
int
(),
blank
=
termination_symbol
,
blank
=
termination_symbol
,
)
)
torch_grad
=
torch
.
autograd
.
grad
(
torch_loss
,
log
prob
s2
)
torch_grad
=
torch
.
autograd
.
grad
(
torch_loss
,
log
it
s2
)
torch_grad
=
torch_grad
[
0
]
torch_grad
=
torch_grad
[
0
]
assert
torch
.
allclose
(
k2
_loss
,
torch_loss
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
fast
_loss
,
torch_loss
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
k2
_grad
,
torch_grad
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
fast
_grad
,
torch_grad
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
test_rnnt_loss_smoothed
(
self
):
def
test_rnnt_loss_smoothed
(
self
):
B
=
1
B
=
1
...
@@ -450,14 +450,13 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -450,14 +450,13 @@ class TestRnntLoss(unittest.TestCase):
lm
=
lm_
.
to
(
device
)
lm
=
lm_
.
to
(
device
)
symbols
=
symbols_
.
to
(
device
)
symbols
=
symbols_
.
to
(
device
)
boundary
=
boundary_
.
to
(
device
)
boundary
=
boundary_
.
to
(
device
)
t_am
=
am
.
unsqueeze
(
2
).
float
()
logits
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
t_lm
=
lm
.
unsqueeze
(
1
).
float
()
logits
=
logits
.
float
()
t_prob
=
t_am
+
t_lm
# nonlinear transform
# nonlinear transform
t_prob
=
torch
.
sigmoid
(
t_prob
)
logits
=
torch
.
sigmoid
(
logits
)
k2
_loss
=
fast_rnnt
.
rnnt_loss
(
fast
_loss
=
fast_rnnt
.
rnnt_loss
(
logits
=
t_prob
,
logits
=
logits
,
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
...
@@ -465,11 +464,11 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -465,11 +464,11 @@ class TestRnntLoss(unittest.TestCase):
)
)
print
(
print
(
f
"
u
npruned rnnt loss with modified
{
modified
}
:
{
k2
_loss
}
"
f
"
U
npruned rnnt loss with modified
{
modified
}
:
{
fast
_loss
}
"
)
)
# pruning
# pruning
k2_
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
lm
=
lm
,
lm
=
lm
,
am
=
am
,
am
=
am
,
symbols
=
symbols
,
symbols
=
symbols
,
...
@@ -488,15 +487,15 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -488,15 +487,15 @@ class TestRnntLoss(unittest.TestCase):
s_range
=
r
,
s_range
=
r
,
)
)
# (B, T, r, C)
# (B, T, r, C)
am_p
,
lm_p
=
fast_rnnt
.
do_rnnt_pruning
(
am
=
am
,
lm
=
lm
,
ranges
=
ranges
)
pruned_am
,
pruned_lm
=
fast_rnnt
.
do_rnnt_pruning
(
am
=
am
,
lm
=
lm
,
ranges
=
ranges
)
t_prob_p
=
am_p
+
lm_p
logits
=
pruned_am
+
pruned_lm
# nonlinear transform
# nonlinear transform
t_prob_p
=
torch
.
sigmoid
(
t_prob_p
)
logits
=
torch
.
sigmoid
(
logits
)
pruned_loss
=
fast_rnnt
.
rnnt_loss_pruned
(
pruned_loss
=
fast_rnnt
.
rnnt_loss_pruned
(
logits
=
t_prob_p
,
logits
=
logits
,
symbols
=
symbols
,
symbols
=
symbols
,
ranges
=
ranges
,
ranges
=
ranges
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
...
@@ -504,8 +503,104 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -504,8 +503,104 @@ class TestRnntLoss(unittest.TestCase):
modified
=
modified
,
modified
=
modified
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
print
(
f
"
p
runing loss with range
{
r
}
:
{
pruned_loss
}
"
)
print
(
f
"
P
runing loss with range
{
r
}
:
{
pruned_loss
}
"
)
# Test the sequences that only have small number of symbols,
# at this circumstance, the s_range would be greater than S, which will
# raise errors (like, nan or inf loss) in our previous versions.
def
test_rnnt_loss_pruned_small_symbols_number
(
self
):
B
=
2
T
=
20
S
=
3
C
=
10
frames
=
torch
.
randint
(
S
+
1
,
T
,
(
B
,))
seq_lengths
=
torch
.
randint
(
1
,
S
,
(
B
,))
T
=
torch
.
max
(
frames
)
S
=
torch
.
max
(
seq_lengths
)
am_
=
torch
.
randn
((
B
,
T
,
C
),
dtype
=
torch
.
float64
)
lm_
=
torch
.
randn
((
B
,
S
+
1
,
C
),
dtype
=
torch
.
float64
)
symbols_
=
torch
.
randint
(
0
,
C
,
(
B
,
S
))
terminal_symbol
=
C
-
1
boundary_
=
torch
.
zeros
((
B
,
4
),
dtype
=
torch
.
int64
)
boundary_
[:,
2
]
=
seq_lengths
boundary_
[:,
3
]
=
frames
print
(
f
"B =
{
B
}
, T =
{
T
}
, S =
{
S
}
, C =
{
C
}
"
)
for
modified
in
[
True
,
False
]:
for
device
in
self
.
devices
:
# normal rnnt
am
=
am_
.
to
(
device
)
lm
=
lm_
.
to
(
device
)
symbols
=
symbols_
.
to
(
device
)
boundary
=
boundary_
.
to
(
device
)
logits
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
logits
=
logits
.
float
()
# nonlinear transform
logits
=
torch
.
sigmoid
(
logits
)
loss
=
fast_rnnt
.
rnnt_loss
(
logits
=
logits
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
modified
=
modified
,
reduction
=
"none"
,
)
print
(
f
"Unpruned rnnt loss with modified
{
modified
}
:
{
loss
}
"
)
# pruning
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
modified
=
modified
,
return_grad
=
True
,
reduction
=
"none"
,
)
S0
=
2
if
modified
:
S0
=
1
for
r
in
range
(
S0
,
S
+
2
):
ranges
=
fast_rnnt
.
get_rnnt_prune_ranges
(
px_grad
=
px_grad
,
py_grad
=
py_grad
,
boundary
=
boundary
,
s_range
=
r
,
)
# (B, T, r, C)
pruned_am
,
pruned_lm
=
fast_rnnt
.
do_rnnt_pruning
(
am
=
am
,
lm
=
lm
,
ranges
=
ranges
)
logits
=
pruned_am
+
pruned_lm
# nonlinear transform
logits
=
torch
.
sigmoid
(
logits
)
pruned_loss
=
fast_rnnt
.
rnnt_loss_pruned
(
logits
=
logits
,
symbols
=
symbols
,
ranges
=
ranges
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
modified
=
modified
,
reduction
=
"none"
,
)
print
(
f
"Pruned loss with range
{
r
}
:
{
pruned_loss
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment