Commit 0f4e97aa authored by Daniel Povey's avatar Daniel Povey
Browse files

Modify documentation to clarify that empty intervals are allowed.

parent b172f0bc
......@@ -161,7 +161,8 @@ def mutual_information_recursion(px, py, boundary=None):
has stride of 1; this is true if px and py are contiguous.
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end < S and
0 <= t_begin <= t_end < T (this implies that empty sequences are allowed). If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
......
......@@ -15,7 +15,7 @@ def test_mutual_information_basic():
random.randint(1, 200))
random_px = (random.random() < 0.2)
random_py = (random.random() < 0.2)
random_boundary = (random.random() < 0.2)
random_boundary = (random.random() < 0.7)
big_px = (random.random() < 0.2)
big_py = (random.random() < 0.2)
......@@ -32,8 +32,8 @@ def test_mutual_information_basic():
def get_boundary_row():
s_begin = random.randint(0, S - 1)
t_begin = random.randint(0, T - 1)
s_end = random.randint(s_begin + 1, S)
t_end = random.randint(t_begin + 1, T)
s_end = random.randint(s_begin, S) # allow empty sequence
t_end = random.randint(t_begin, T) # allow empty sequence
return [s_begin, t_begin, s_end, t_end]
if device == torch.device('cpu'):
boundary = torch.tensor([ get_boundary_row() for _ in range(B) ],
......@@ -73,7 +73,7 @@ def test_mutual_information_basic():
#m = mutual_information_recursion(px, py, None)
m = mutual_information_recursion(px, py, boundary)
#print("m = ", m, ", size = ", m.shape)
print("m = ", m, ", size = ", m.shape)
#print("exp(m) = ", m.exp())
(m.sum() * 3).backward()
#print("px_grad = ", px.grad)
......@@ -101,7 +101,7 @@ def test_mutual_information_deriv():
random.randint(1, 200))
random_px = (random.random() < 0.2)
random_py = (random.random() < 0.2)
random_boundary = (random.random() < 0.2)
random_boundary = (random.random() < 0.7)
big_px = (random.random() < 0.2)
big_py = (random.random() < 0.2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment