Unverified Commit 4bff8cc2 authored by Yifan Yang's avatar Yifan Yang Committed by GitHub
Browse files

Update rnnt_loss.py

parent 445aca29
......@@ -773,9 +773,9 @@ def do_rnnt_pruning(
lm,
dim=1,
index=ranges.reshape(B, T * s_range, 1).expand(
(B, T * s_range, decoder_dim)
(B, T * s_range, C)
),
).reshape(B, T, s_range, decoder_dim)
).reshape(B, T, s_range, C)
return am_pruning, lm_pruning
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment