Commit 9ac065f0 authored by Daniel Povey's avatar Daniel Povey
Browse files

Some bug fixes, testing..

parent 1ad556dc
...@@ -189,7 +189,7 @@ def mutual_information_recursion(px, py, boundary=None): ...@@ -189,7 +189,7 @@ def mutual_information_recursion(px, py, boundary=None):
assert px.dtype == py.dtype assert px.dtype == py.dtype
(B, S, T) = px.shape (B, S, T) = px.shape
if boundary is not None: if boundary is not None:
assert boundary.dtype == torch.LongTensor assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4) assert boundary.shape == (B, 4)
......
...@@ -9,16 +9,19 @@ from torch_mutual_information import mutual_information_recursion ...@@ -9,16 +9,19 @@ from torch_mutual_information import mutual_information_recursion
def test_mutual_information_basic(): def test_mutual_information_basic():
print("Running test_mutual_information_basic()") print("Running test_mutual_information_basic()")
for dtype in [torch.float32, torch.float64]: for dtype in [torch.float32, torch.float64]:
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
B = 2 B = 2
S = 4 S = 4
T = 5 T = 5
px = torch.zeros(B, S, T + 1) # log of an odds ratio boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
py = torch.zeros(B, S + 1, T) # log of an odds ratio px = torch.zeros(B, S, T + 1).to(device) # log of an odds ratio
py = torch.zeros(B, S + 1, T).to(device) # log of an odds ratio
m = mutual_information_recursion(px, py) m = mutual_information_recursion(px, py, boundary)
print("m = ", m) print("m = ", m)
print("exp(m) = ", m.exp())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment