Commit 6afe9951 authored by pkufool's avatar pkufool
Browse files

Fix potential bug and add more docs

parent 15a3d1cd
......@@ -557,9 +557,13 @@ def get_rnnt_prune_ranges(
s_range >= 2
), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."
(B_stride, S_stride, T_stride) = py_grad.stride()
blk_grad = torch.as_strided(
py_grad, (B, S1 - s_range + 1, s_range, T), (S1 * T, T, T, 1)
py_grad,
(B, S1 - s_range + 1, s_range, T),
(B_stride, S_stride, S_stride, T_stride),
)
# (B, S1 - s_range + 1, T)
blk_sum_grad = torch.sum(blk_grad, axis=2)
......@@ -572,13 +576,17 @@ def get_rnnt_prune_ranges(
# (B, T)
s_begin = torch.argmax(final_grad, axis=1)
s_begin = s_begin[:, :T]
# Handle the values of s_begin in padding positions.
# -1 here means we fill the position of the last frame of real data with
# -1 here means we fill the position of the last frame (before padding) with
# padding value which is `len(symbols) - s_range + 1`.
# This is to guarantee that we reach the last symbol at last frame of real
# data.
# This is to guarantee that we reach the last symbol at last frame (before
# padding).
# The shape of the mask is (B, T), for example, we have a batch containing
# 3 sequences, their lengths are 3, 5, 6 (i.e. B = 3, T = 6), so the mask is
# [[True, True, False, False, False, False],
# [True, True, True, True, False, False],
# [True, True, True, True, True, False]]
mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
mask = mask < boundary[:, 3].reshape(B, 1) - 1
......@@ -589,7 +597,7 @@ def get_rnnt_prune_ranges(
s_begin = torch.where(mask, s_begin, s_begin_padding)
# adjusting lower bound to make it satisfied some constrains, see docs in
# `adjust_pruning_lower_bound` for more details of these constrains.
# `_adjust_pruning_lower_bound` for more details of these constrains.
# T1 == T here means we are using the modified version of transducer,
# the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
# it only emits one symbol per frame.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment