Unverified Commit 878d7c81 authored by Yifan Yang's avatar Yifan Yang Committed by GitHub
Browse files

Update rnnt_loss.py

parent 79722682
......@@ -167,12 +167,10 @@ def get_rnnt_logprobs(
# px is the probs of the actual symbols..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)
if rnnt_type == "regular":
px_am = torch.cat(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment