Unverified Commit 4bff8cc2 authored by Yifan Yang's avatar Yifan Yang Committed by GitHub
Browse files

Update rnnt_loss.py

parent 445aca29
...@@ -773,9 +773,9 @@ def do_rnnt_pruning( ...@@ -773,9 +773,9 @@ def do_rnnt_pruning(
lm, lm,
dim=1, dim=1,
index=ranges.reshape(B, T * s_range, 1).expand( index=ranges.reshape(B, T * s_range, 1).expand(
(B, T * s_range, decoder_dim) (B, T * s_range, C)
), ),
).reshape(B, T, s_range, decoder_dim) ).reshape(B, T, s_range, C)
return am_pruning, lm_pruning return am_pruning, lm_pruning
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment