Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
0f4e97aa
Commit
0f4e97aa
authored
Jul 31, 2021
by
Daniel Povey
Browse files
Modify documentation to clarify that empty intervals are allowed.
parent
b172f0bc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
6 deletions
+7
-6
torch_mutual_information/mutual_information.py
torch_mutual_information/mutual_information.py
+2
-1
torch_mutual_information/mutual_information_test.py
torch_mutual_information/mutual_information_test.py
+5
-5
No files found.
torch_mutual_information/mutual_information.py
View file @
0f4e97aa
...
@@ -161,7 +161,8 @@ def mutual_information_recursion(px, py, boundary=None):
...
@@ -161,7 +161,8 @@ def mutual_information_recursion(px, py, boundary=None):
has stride of 1; this is true if px and py are contiguous.
has stride of 1; this is true if px and py are contiguous.
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end < S and
0 <= t_begin <= t_end < T (this implies that empty sequences are allowed). If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
respectively, and can be used if not all sequences are of the same length.
...
...
torch_mutual_information/mutual_information_test.py
View file @
0f4e97aa
...
@@ -15,7 +15,7 @@ def test_mutual_information_basic():
...
@@ -15,7 +15,7 @@ def test_mutual_information_basic():
random
.
randint
(
1
,
200
))
random
.
randint
(
1
,
200
))
random_px
=
(
random
.
random
()
<
0.2
)
random_px
=
(
random
.
random
()
<
0.2
)
random_py
=
(
random
.
random
()
<
0.2
)
random_py
=
(
random
.
random
()
<
0.2
)
random_boundary
=
(
random
.
random
()
<
0.
2
)
random_boundary
=
(
random
.
random
()
<
0.
7
)
big_px
=
(
random
.
random
()
<
0.2
)
big_px
=
(
random
.
random
()
<
0.2
)
big_py
=
(
random
.
random
()
<
0.2
)
big_py
=
(
random
.
random
()
<
0.2
)
...
@@ -32,8 +32,8 @@ def test_mutual_information_basic():
...
@@ -32,8 +32,8 @@ def test_mutual_information_basic():
def
get_boundary_row
():
def
get_boundary_row
():
s_begin
=
random
.
randint
(
0
,
S
-
1
)
s_begin
=
random
.
randint
(
0
,
S
-
1
)
t_begin
=
random
.
randint
(
0
,
T
-
1
)
t_begin
=
random
.
randint
(
0
,
T
-
1
)
s_end
=
random
.
randint
(
s_begin
+
1
,
S
)
s_end
=
random
.
randint
(
s_begin
,
S
)
# allow empty sequence
t_end
=
random
.
randint
(
t_begin
+
1
,
T
)
t_end
=
random
.
randint
(
t_begin
,
T
)
# allow empty sequence
return
[
s_begin
,
t_begin
,
s_end
,
t_end
]
return
[
s_begin
,
t_begin
,
s_end
,
t_end
]
if
device
==
torch
.
device
(
'cpu'
):
if
device
==
torch
.
device
(
'cpu'
):
boundary
=
torch
.
tensor
([
get_boundary_row
()
for
_
in
range
(
B
)
],
boundary
=
torch
.
tensor
([
get_boundary_row
()
for
_
in
range
(
B
)
],
...
@@ -73,7 +73,7 @@ def test_mutual_information_basic():
...
@@ -73,7 +73,7 @@ def test_mutual_information_basic():
#m = mutual_information_recursion(px, py, None)
#m = mutual_information_recursion(px, py, None)
m
=
mutual_information_recursion
(
px
,
py
,
boundary
)
m
=
mutual_information_recursion
(
px
,
py
,
boundary
)
#
print("m = ", m, ", size = ", m.shape)
print
(
"m = "
,
m
,
", size = "
,
m
.
shape
)
#print("exp(m) = ", m.exp())
#print("exp(m) = ", m.exp())
(
m
.
sum
()
*
3
).
backward
()
(
m
.
sum
()
*
3
).
backward
()
#print("px_grad = ", px.grad)
#print("px_grad = ", px.grad)
...
@@ -101,7 +101,7 @@ def test_mutual_information_deriv():
...
@@ -101,7 +101,7 @@ def test_mutual_information_deriv():
random
.
randint
(
1
,
200
))
random
.
randint
(
1
,
200
))
random_px
=
(
random
.
random
()
<
0.2
)
random_px
=
(
random
.
random
()
<
0.2
)
random_py
=
(
random
.
random
()
<
0.2
)
random_py
=
(
random
.
random
()
<
0.2
)
random_boundary
=
(
random
.
random
()
<
0.
2
)
random_boundary
=
(
random
.
random
()
<
0.
7
)
big_px
=
(
random
.
random
()
<
0.2
)
big_px
=
(
random
.
random
()
<
0.2
)
big_py
=
(
random
.
random
()
<
0.2
)
big_py
=
(
random
.
random
()
<
0.2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment