Commit e7d9810d authored by pkufool's avatar pkufool
Browse files

Minor fixes

parent b32f8a26
......@@ -591,15 +591,15 @@ def _adjust_pruning_lower_bound(
"""
# s_begin (B, T)
(B, T) = s_begin.shape
_monotonic_lower_bound(s_begin)
s_begin = _monotonic_lower_bound(s_begin)
# do the magic transformation
s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
)
# make the transformed tensor to be non-decreasing
_monotonic_lower_bound(s_begin)
s_begin = _monotonic_lower_bound(s_begin)
# make start symbol to be zero.
s_begin = torch.where(s_begin < 0, 0, s_begin)
s_begin = torch.clamp(s_begin, min=0)
# do the magic transformation again to recover s_begin
s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
......@@ -830,6 +830,12 @@ def get_rnnt_logprobs_pruned(
{0..C-1}.
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
......@@ -996,6 +1002,12 @@ def rnnt_loss_pruned(
of the sequence.
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment