Commit 0f4e97aa authored by Daniel Povey's avatar Daniel Povey
Browse files

Modify documentation to clarify that empty intervals are allowed.

parent b172f0bc
...@@ -161,7 +161,8 @@ def mutual_information_recursion(px, py, boundary=None): ...@@ -161,7 +161,8 @@ def mutual_information_recursion(px, py, boundary=None):
has stride of 1; this is true if px and py are contiguous. has stride of 1; this is true if px and py are contiguous.
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values [s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end < S and
0 <= t_begin <= t_end < T (this implies that empty sequences are allowed). If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and [0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length. respectively, and can be used if not all sequences are of the same length.
......
...@@ -15,7 +15,7 @@ def test_mutual_information_basic(): ...@@ -15,7 +15,7 @@ def test_mutual_information_basic():
random.randint(1, 200)) random.randint(1, 200))
random_px = (random.random() < 0.2) random_px = (random.random() < 0.2)
random_py = (random.random() < 0.2) random_py = (random.random() < 0.2)
random_boundary = (random.random() < 0.2) random_boundary = (random.random() < 0.7)
big_px = (random.random() < 0.2) big_px = (random.random() < 0.2)
big_py = (random.random() < 0.2) big_py = (random.random() < 0.2)
...@@ -32,8 +32,8 @@ def test_mutual_information_basic(): ...@@ -32,8 +32,8 @@ def test_mutual_information_basic():
def get_boundary_row(): def get_boundary_row():
s_begin = random.randint(0, S - 1) s_begin = random.randint(0, S - 1)
t_begin = random.randint(0, T - 1) t_begin = random.randint(0, T - 1)
s_end = random.randint(s_begin + 1, S) s_end = random.randint(s_begin, S) # allow empty sequence
t_end = random.randint(t_begin + 1, T) t_end = random.randint(t_begin, T) # allow empty sequence
return [s_begin, t_begin, s_end, t_end] return [s_begin, t_begin, s_end, t_end]
if device == torch.device('cpu'): if device == torch.device('cpu'):
boundary = torch.tensor([ get_boundary_row() for _ in range(B) ], boundary = torch.tensor([ get_boundary_row() for _ in range(B) ],
...@@ -73,7 +73,7 @@ def test_mutual_information_basic(): ...@@ -73,7 +73,7 @@ def test_mutual_information_basic():
#m = mutual_information_recursion(px, py, None) #m = mutual_information_recursion(px, py, None)
m = mutual_information_recursion(px, py, boundary) m = mutual_information_recursion(px, py, boundary)
#print("m = ", m, ", size = ", m.shape) print("m = ", m, ", size = ", m.shape)
#print("exp(m) = ", m.exp()) #print("exp(m) = ", m.exp())
(m.sum() * 3).backward() (m.sum() * 3).backward()
#print("px_grad = ", px.grad) #print("px_grad = ", px.grad)
...@@ -101,7 +101,7 @@ def test_mutual_information_deriv(): ...@@ -101,7 +101,7 @@ def test_mutual_information_deriv():
random.randint(1, 200)) random.randint(1, 200))
random_px = (random.random() < 0.2) random_px = (random.random() < 0.2)
random_py = (random.random() < 0.2) random_py = (random.random() < 0.2)
random_boundary = (random.random() < 0.2) random_boundary = (random.random() < 0.7)
big_px = (random.random() < 0.2) big_px = (random.random() < 0.2)
big_py = (random.random() < 0.2) big_py = (random.random() < 0.2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment