Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
6afe9951
Commit
6afe9951
authored
Aug 07, 2022
by
pkufool
Browse files
Fix potential bug and add more docs
parent
15a3d1cd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
6 deletions
+14
-6
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+14
-6
No files found.
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
6afe9951
...
@@ -557,9 +557,13 @@ def get_rnnt_prune_ranges(
...
@@ -557,9 +557,13 @@ def get_rnnt_prune_ranges(
s_range
>=
2
s_range
>=
2
),
"Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
),
"Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
(
B_stride
,
S_stride
,
T_stride
)
=
py_grad
.
stride
()
blk_grad
=
torch
.
as_strided
(
blk_grad
=
torch
.
as_strided
(
py_grad
,
(
B
,
S1
-
s_range
+
1
,
s_range
,
T
),
(
S1
*
T
,
T
,
T
,
1
)
py_grad
,
(
B
,
S1
-
s_range
+
1
,
s_range
,
T
),
(
B_stride
,
S_stride
,
S_stride
,
T_stride
),
)
)
# (B, S1 - s_range + 1, T)
# (B, S1 - s_range + 1, T)
blk_sum_grad
=
torch
.
sum
(
blk_grad
,
axis
=
2
)
blk_sum_grad
=
torch
.
sum
(
blk_grad
,
axis
=
2
)
...
@@ -572,13 +576,17 @@ def get_rnnt_prune_ranges(
...
@@ -572,13 +576,17 @@ def get_rnnt_prune_ranges(
# (B, T)
# (B, T)
s_begin
=
torch
.
argmax
(
final_grad
,
axis
=
1
)
s_begin
=
torch
.
argmax
(
final_grad
,
axis
=
1
)
s_begin
=
s_begin
[:,
:
T
]
# Handle the values of s_begin in padding positions.
# Handle the values of s_begin in padding positions.
# -1 here means we fill the position of the last frame
of real data
with
# -1 here means we fill the position of the last frame
(before padding)
with
# padding value which is `len(symbols) - s_range + 1`.
# padding value which is `len(symbols) - s_range + 1`.
# This is to guarantee that we reach the last symbol at last frame of real
# This is to guarantee that we reach the last symbol at last frame (before
# data.
# padding).
# The shape of the mask is (B, T), for example, we have a batch containing
# 3 sequences, their lengths are 3, 5, 6 (i.e. B = 3, T = 6), so the mask is
# [[True, True, False, False, False, False],
# [True, True, True, True, False, False],
# [True, True, True, True, True, False]]
mask
=
torch
.
arange
(
0
,
T
,
device
=
px_grad
.
device
).
reshape
(
1
,
T
).
expand
(
B
,
T
)
mask
=
torch
.
arange
(
0
,
T
,
device
=
px_grad
.
device
).
reshape
(
1
,
T
).
expand
(
B
,
T
)
mask
=
mask
<
boundary
[:,
3
].
reshape
(
B
,
1
)
-
1
mask
=
mask
<
boundary
[:,
3
].
reshape
(
B
,
1
)
-
1
...
@@ -589,7 +597,7 @@ def get_rnnt_prune_ranges(
...
@@ -589,7 +597,7 @@ def get_rnnt_prune_ranges(
s_begin
=
torch
.
where
(
mask
,
s_begin
,
s_begin_padding
)
s_begin
=
torch
.
where
(
mask
,
s_begin
,
s_begin_padding
)
# adjusting lower bound to make it satisfied some constrains, see docs in
# adjusting lower bound to make it satisfied some constrains, see docs in
# `adjust_pruning_lower_bound` for more details of these constrains.
# `
_
adjust_pruning_lower_bound` for more details of these constrains.
# T1 == T here means we are using the modified version of transducer,
# T1 == T here means we are using the modified version of transducer,
# the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
# the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
# it only emits one symbol per frame.
# it only emits one symbol per frame.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment