Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
d55f5481
Commit
d55f5481
authored
Aug 11, 2022
by
pkufool
Browse files
Fix contiguous
parent
c268c3d5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
9 deletions
+6
-9
fast_rnnt/python/fast_rnnt/mutual_information.py
fast_rnnt/python/fast_rnnt/mutual_information.py
+6
-4
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+0
-5
No files found.
fast_rnnt/python/fast_rnnt/mutual_information.py
View file @
d55f5481
...
@@ -285,9 +285,10 @@ def mutual_information_recursion(
...
@@ -285,9 +285,10 @@ def mutual_information_recursion(
for
s_begin
,
t_begin
,
s_end
,
t_end
in
boundary
.
tolist
():
for
s_begin
,
t_begin
,
s_end
,
t_end
in
boundary
.
tolist
():
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
t_begin
<=
t_end
<=
T
assert
0
<=
t_begin
<=
t_end
<=
T
# The following assertions are for efficiency
assert
px
.
is_contiguous
()
# The following statements are for efficiency
assert
py
.
is_contiguous
()
px
,
py
=
px
.
is_contiguous
(),
py
.
is_contiguous
()
pxy_grads
=
[
None
,
None
]
pxy_grads
=
[
None
,
None
]
scores
=
MutualInformationRecursionFunction
.
apply
(
px
,
py
,
pxy_grads
,
scores
=
MutualInformationRecursionFunction
.
apply
(
px
,
py
,
pxy_grads
,
boundary
,
return_grad
)
boundary
,
return_grad
)
...
@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(
...
@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
t_begin
<=
t_end
<=
T
assert
0
<=
t_begin
<=
t_end
<=
T
# The following statements are for efficiency
px_tot
,
py_tot
=
px_tot
.
contiguous
(),
py_tot
.
contiguous
()
px_tot
,
py_tot
=
px_tot
.
contiguous
(),
py_tot
.
contiguous
()
# The following assertions are for efficiency
assert
px_tot
.
ndim
==
3
assert
px_tot
.
ndim
==
3
assert
py_tot
.
ndim
==
3
assert
py_tot
.
ndim
==
3
...
...
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
d55f5481
...
@@ -361,8 +361,6 @@ def get_rnnt_logprobs_joint(
...
@@ -361,8 +361,6 @@ def get_rnnt_logprobs_joint(
logits
[:,
:,
:,
termination_symbol
].
permute
((
0
,
2
,
1
)).
clone
()
logits
[:,
:,
:,
termination_symbol
].
permute
((
0
,
2
,
1
)).
clone
()
)
# [B][S+1][T]
)
# [B][S+1][T]
py
-=
normalizers
py
-=
normalizers
px
=
px
.
contiguous
()
py
=
py
.
contiguous
()
if
not
modified
:
if
not
modified
:
px
=
fix_for_boundary
(
px
,
boundary
)
px
=
fix_for_boundary
(
px
,
boundary
)
...
@@ -807,9 +805,6 @@ def get_rnnt_logprobs_pruned(
...
@@ -807,9 +805,6 @@ def get_rnnt_logprobs_pruned(
# (B, S + 1, T)
# (B, S + 1, T)
py
=
py
.
permute
((
0
,
2
,
1
))
py
=
py
.
permute
((
0
,
2
,
1
))
px
=
px
.
contiguous
()
py
=
py
.
contiguous
()
if
not
modified
:
if
not
modified
:
px
=
fix_for_boundary
(
px
,
boundary
)
px
=
fix_for_boundary
(
px
,
boundary
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment