Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
0f4e97aa
Commit
0f4e97aa
authored
Jul 31, 2021
by
Daniel Povey
Browse files
Modify documentation to clarify that empty intervals are allowed.
parent
b172f0bc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
6 deletions
+7
-6
torch_mutual_information/mutual_information.py
torch_mutual_information/mutual_information.py
+2
-1
torch_mutual_information/mutual_information_test.py
torch_mutual_information/mutual_information_test.py
+5
-5
No files found.
torch_mutual_information/mutual_information.py
View file @
0f4e97aa
...
...
@@ -161,7 +161,8 @@ def mutual_information_recursion(px, py, boundary=None):
has stride of 1; this is true if px and py are contiguous.
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end < S and
0 <= t_begin <= t_end < T (this implies that empty sequences are allowed). If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
...
...
torch_mutual_information/mutual_information_test.py
View file @
0f4e97aa
...
...
@@ -15,7 +15,7 @@ def test_mutual_information_basic():
random
.
randint
(
1
,
200
))
random_px
=
(
random
.
random
()
<
0.2
)
random_py
=
(
random
.
random
()
<
0.2
)
random_boundary
=
(
random
.
random
()
<
0.
2
)
random_boundary
=
(
random
.
random
()
<
0.
7
)
big_px
=
(
random
.
random
()
<
0.2
)
big_py
=
(
random
.
random
()
<
0.2
)
...
...
@@ -32,8 +32,8 @@ def test_mutual_information_basic():
def
get_boundary_row
():
s_begin
=
random
.
randint
(
0
,
S
-
1
)
t_begin
=
random
.
randint
(
0
,
T
-
1
)
s_end
=
random
.
randint
(
s_begin
+
1
,
S
)
t_end
=
random
.
randint
(
t_begin
+
1
,
T
)
s_end
=
random
.
randint
(
s_begin
,
S
)
# allow empty sequence
t_end
=
random
.
randint
(
t_begin
,
T
)
# allow empty sequence
return
[
s_begin
,
t_begin
,
s_end
,
t_end
]
if
device
==
torch
.
device
(
'cpu'
):
boundary
=
torch
.
tensor
([
get_boundary_row
()
for
_
in
range
(
B
)
],
...
...
@@ -73,7 +73,7 @@ def test_mutual_information_basic():
#m = mutual_information_recursion(px, py, None)
m
=
mutual_information_recursion
(
px
,
py
,
boundary
)
#
print("m = ", m, ", size = ", m.shape)
print
(
"m = "
,
m
,
", size = "
,
m
.
shape
)
#print("exp(m) = ", m.exp())
(
m
.
sum
()
*
3
).
backward
()
#print("px_grad = ", px.grad)
...
...
@@ -101,7 +101,7 @@ def test_mutual_information_deriv():
random
.
randint
(
1
,
200
))
random_px
=
(
random
.
random
()
<
0.2
)
random_py
=
(
random
.
random
()
<
0.2
)
random_boundary
=
(
random
.
random
()
<
0.
2
)
random_boundary
=
(
random
.
random
()
<
0.
7
)
big_px
=
(
random
.
random
()
<
0.2
)
big_py
=
(
random
.
random
()
<
0.2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment