Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
79722682
Unverified
Commit
79722682
authored
Apr 25, 2023
by
Yifan Yang
Committed by
GitHub
Apr 25, 2023
Browse files
Update rnnt_loss.py
parent
2c2dc4b9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
6 deletions
+4
-6
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+4
-6
No files found.
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
79722682
...
...
@@ -1245,12 +1245,10 @@ def get_rnnt_logprobs_smoothed(
# px is the probs of the actual symbols (not yet normalized)..
px_am
=
torch
.
gather
(
am
.
unsqueeze
(
1
).
expand
(
B
,
S
,
T
,
C
),
dim
=
3
,
index
=
symbols
.
reshape
(
B
,
S
,
1
,
1
).
expand
(
B
,
S
,
T
,
1
),
).
squeeze
(
-
1
)
# [B][S][T]
am
.
transpose
(
1
,
2
),
# (B, C, T)
dim
=
1
,
index
=
symbols
.
unsqueeze
(
2
).
expand
(
B
,
S
,
T
),
)
# (B, S, T)
if
rnnt_type
==
"regular"
:
px_am
=
torch
.
cat
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment