Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
a801adc3
"docs/vscode:/vscode.git/clone" did not exist on "7e257cd666c0d639626487987ea8e590da1e9395"
Unverified
Commit
a801adc3
authored
Jul 19, 2023
by
Daniel Povey
Committed by
GitHub
Jul 19, 2023
Browse files
Merge pull request #24 from yfyeung/yfyeung-patch-1
Update rnnt_loss.py
parents
2945bd7d
878d7c81
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
12 deletions
+8
-12
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+8
-12
No files found.
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
a801adc3
...
...
@@ -167,12 +167,10 @@ def get_rnnt_logprobs(
# px is the probs of the actual symbols..
px_am
=
torch
.
gather
(
am
.
unsqueeze
(
1
).
expand
(
B
,
S
,
T
,
C
),
dim
=
3
,
index
=
symbols
.
reshape
(
B
,
S
,
1
,
1
).
expand
(
B
,
S
,
T
,
1
),
).
squeeze
(
-
1
)
# [B][S][T]
am
.
transpose
(
1
,
2
),
# (B, C, T)
dim
=
1
,
index
=
symbols
.
unsqueeze
(
2
).
expand
(
B
,
S
,
T
),
)
# (B, S, T)
if
rnnt_type
==
"regular"
:
px_am
=
torch
.
cat
(
...
...
@@ -1247,12 +1245,10 @@ def get_rnnt_logprobs_smoothed(
# px is the probs of the actual symbols (not yet normalized)..
px_am
=
torch
.
gather
(
am
.
unsqueeze
(
1
).
expand
(
B
,
S
,
T
,
C
),
dim
=
3
,
index
=
symbols
.
reshape
(
B
,
S
,
1
,
1
).
expand
(
B
,
S
,
T
,
1
),
).
squeeze
(
-
1
)
# [B][S][T]
am
.
transpose
(
1
,
2
),
# (B, C, T)
dim
=
1
,
index
=
symbols
.
unsqueeze
(
2
).
expand
(
B
,
S
,
T
),
)
# (B, S, T)
if
rnnt_type
==
"regular"
:
px_am
=
torch
.
cat
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment