Unverified Commit 445aca29 authored by Yifan Yang's avatar Yifan Yang Committed by GitHub
Browse files

Update rnnt_loss.py

parent 2c2dc4b9
...@@ -770,10 +770,12 @@ def do_rnnt_pruning( ...@@ -770,10 +770,12 @@ def do_rnnt_pruning(
# (B, T, s_range, C) # (B, T, s_range, C)
lm_pruning = torch.gather( lm_pruning = torch.gather(
lm.unsqueeze(1).expand((B, T, S + 1, C)), lm,
dim=2, dim=1,
index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C)), index=ranges.reshape(B, T * s_range, 1).expand(
) (B, T * s_range, decoder_dim)
),
).reshape(B, T, s_range, decoder_dim)
return am_pruning, lm_pruning return am_pruning, lm_pruning
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment