Unverified Commit 2945bd7d authored by Daniel Povey's avatar Daniel Povey Committed by GitHub
Browse files

Merge pull request #26 from yfyeung/patch-1

Update rnnt_loss.py
parents 2c2dc4b9 4bff8cc2
......@@ -770,10 +770,12 @@ def do_rnnt_pruning(
# (B, T, s_range, C)
lm_pruning = torch.gather(
lm.unsqueeze(1).expand((B, T, S + 1, C)),
dim=2,
index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C)),
)
lm,
dim=1,
index=ranges.reshape(B, T * s_range, 1).expand(
(B, T * s_range, C)
),
).reshape(B, T, s_range, C)
return am_pruning, lm_pruning
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment