Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
e7d9810d
Commit
e7d9810d
authored
Oct 31, 2022
by
pkufool
Browse files
Minor fixes
parent
b32f8a26
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
3 deletions
+15
-3
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+15
-3
No files found.
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
e7d9810d
...
@@ -591,15 +591,15 @@ def _adjust_pruning_lower_bound(
...
@@ -591,15 +591,15 @@ def _adjust_pruning_lower_bound(
"""
"""
# s_begin (B, T)
# s_begin (B, T)
(
B
,
T
)
=
s_begin
.
shape
(
B
,
T
)
=
s_begin
.
shape
_monotonic_lower_bound
(
s_begin
)
s_begin
=
_monotonic_lower_bound
(
s_begin
)
# do the magic transformation
# do the magic transformation
s_begin
=
-
(
s_begin
=
-
(
s_begin
-
(
s_range
-
1
)
*
torch
.
arange
(
0
,
T
,
device
=
s_begin
.
device
)
s_begin
-
(
s_range
-
1
)
*
torch
.
arange
(
0
,
T
,
device
=
s_begin
.
device
)
)
)
# make the transformed tensor to be non-decreasing
# make the transformed tensor to be non-decreasing
_monotonic_lower_bound
(
s_begin
)
s_begin
=
_monotonic_lower_bound
(
s_begin
)
# make start symbol to be zero.
# make start symbol to be zero.
s_begin
=
torch
.
where
(
s_begin
<
0
,
0
,
s_begin
)
s_begin
=
torch
.
clamp
(
s_begin
,
min
=
0
)
# do the magic transformation again to recover s_begin
# do the magic transformation again to recover s_begin
s_begin
=
-
(
s_begin
=
-
(
s_begin
-
(
s_range
-
1
)
*
torch
.
arange
(
0
,
T
,
device
=
s_begin
.
device
)
s_begin
-
(
s_range
-
1
)
*
torch
.
arange
(
0
,
T
,
device
=
s_begin
.
device
)
...
@@ -830,6 +830,12 @@ def get_rnnt_logprobs_pruned(
...
@@ -830,6 +830,12 @@ def get_rnnt_logprobs_pruned(
{0..C-1}.
{0..C-1}.
ranges:
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol:
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
the termination symbol, with 0 <= termination_symbol < C
boundary:
boundary:
...
@@ -996,6 +1002,12 @@ def rnnt_loss_pruned(
...
@@ -996,6 +1002,12 @@ def rnnt_loss_pruned(
of the sequence.
of the sequence.
ranges:
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol:
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
The identity of the termination symbol, must be in {0..C-1}
boundary:
boundary:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment