Unverified Commit 878d7c81 authored by Yifan Yang's avatar Yifan Yang Committed by GitHub
Browse files

Update rnnt_loss.py

parent 79722682
...@@ -167,12 +167,10 @@ def get_rnnt_logprobs( ...@@ -167,12 +167,10 @@ def get_rnnt_logprobs(
# px is the probs of the actual symbols.. # px is the probs of the actual symbols..
px_am = torch.gather( px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C), am.transpose(1, 2), # (B, C, T)
dim=3, dim=1,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1), index=symbols.unsqueeze(2).expand(B, S, T),
).squeeze( ) # (B, S, T)
-1
) # [B][S][T]
if rnnt_type == "regular": if rnnt_type == "regular":
px_am = torch.cat( px_am = torch.cat(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment