Commit e7d9810d authored by pkufool's avatar pkufool
Browse files

Minor fixes

parent b32f8a26
...@@ -591,15 +591,15 @@ def _adjust_pruning_lower_bound( ...@@ -591,15 +591,15 @@ def _adjust_pruning_lower_bound(
""" """
# s_begin (B, T) # s_begin (B, T)
(B, T) = s_begin.shape (B, T) = s_begin.shape
_monotonic_lower_bound(s_begin) s_begin = _monotonic_lower_bound(s_begin)
# do the magic transformation # do the magic transformation
s_begin = -( s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device) s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
) )
# make the transformed tensor to be non-decreasing # make the transformed tensor to be non-decreasing
_monotonic_lower_bound(s_begin) s_begin = _monotonic_lower_bound(s_begin)
# make start symbol to be zero. # make start symbol to be zero.
s_begin = torch.where(s_begin < 0, 0, s_begin) s_begin = torch.clamp(s_begin, min=0)
# do the magic transformation again to recover s_begin # do the magic transformation again to recover s_begin
s_begin = -( s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device) s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
...@@ -830,6 +830,12 @@ def get_rnnt_logprobs_pruned( ...@@ -830,6 +830,12 @@ def get_rnnt_logprobs_pruned(
{0..C-1}. {0..C-1}.
ranges: ranges:
A tensor containing the symbol ids for each frame that we want to keep. A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol: termination_symbol:
the termination symbol, with 0 <= termination_symbol < C the termination symbol, with 0 <= termination_symbol < C
boundary: boundary:
...@@ -996,6 +1002,12 @@ def rnnt_loss_pruned( ...@@ -996,6 +1002,12 @@ def rnnt_loss_pruned(
of the sequence. of the sequence.
ranges: ranges:
A tensor containing the symbol ids for each frame that we want to keep. A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol: termination_symbol:
The identity of the termination symbol, must be in {0..C-1} The identity of the termination symbol, must be in {0..C-1}
boundary: boundary:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment