"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "f13d65a7ea711ed3939917b959bc49b4701d719e"
Unverified Commit 445aca29 authored by Yifan Yang's avatar Yifan Yang Committed by GitHub
Browse files

Update rnnt_loss.py

parent 2c2dc4b9
......@@ -770,10 +770,12 @@ def do_rnnt_pruning(
# (B, T, s_range, C)
lm_pruning = torch.gather(
lm.unsqueeze(1).expand((B, T, S + 1, C)),
dim=2,
index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C)),
)
lm,
dim=1,
index=ranges.reshape(B, T * s_range, 1).expand(
(B, T * s_range, decoder_dim)
),
).reshape(B, T, s_range, decoder_dim)
return am_pruning, lm_pruning
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment