Unverified Commit a801adc3 authored by Daniel Povey's avatar Daniel Povey Committed by GitHub
Browse files

Merge pull request #24 from yfyeung/yfyeung-patch-1

Update rnnt_loss.py
parents 2945bd7d 878d7c81
...@@ -167,12 +167,10 @@ def get_rnnt_logprobs( ...@@ -167,12 +167,10 @@ def get_rnnt_logprobs(
# px is the probs of the actual symbols.. # px is the probs of the actual symbols..
px_am = torch.gather( px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C), am.transpose(1, 2), # (B, C, T)
dim=3, dim=1,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1), index=symbols.unsqueeze(2).expand(B, S, T),
).squeeze( ) # (B, S, T)
-1
) # [B][S][T]
if rnnt_type == "regular": if rnnt_type == "regular":
px_am = torch.cat( px_am = torch.cat(
...@@ -1247,12 +1245,10 @@ def get_rnnt_logprobs_smoothed( ...@@ -1247,12 +1245,10 @@ def get_rnnt_logprobs_smoothed(
# px is the probs of the actual symbols (not yet normalized).. # px is the probs of the actual symbols (not yet normalized)..
px_am = torch.gather( px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C), am.transpose(1, 2), # (B, C, T)
dim=3, dim=1,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1), index=symbols.unsqueeze(2).expand(B, S, T),
).squeeze( ) # (B, S, T)
-1
) # [B][S][T]
if rnnt_type == "regular": if rnnt_type == "regular":
px_am = torch.cat( px_am = torch.cat(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment