Commit 1ad556dc authored by Daniel Povey

Fix some bugs..

parent 52ae49ee
from .mutual_information import mutual_information
from .mutual_information import mutual_information_recursion
import os
import torch
from typing import Tuple
from torch import Tensor
from typing import Tuple, Optional
from torch.utils.cpp_extension import load
VERBOSE = False
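Note: the `torch_mutual_information_cpu` / `torch_mutual_information_cuda` modules used by the dispatchers below are native extensions; the `load` import above (and the `except ImportError:` context in the next hunk) suggests they are JIT-compiled via `torch.utils.cpp_extension`. A minimal sketch of what such a build typically looks like — the helper name and source file name are assumptions, not this repository's actual build code:

```python
import os
from torch.utils.cpp_extension import load

def _build_cpu_extension():
    # JIT-compile the C++ source into an importable module.  The file name
    # below is an illustrative placeholder.
    src_dir = os.path.dirname(os.path.realpath(__file__))
    return load(
        name='torch_mutual_information_cpu',
        sources=[os.path.join(src_dir, 'mutual_information_cpu.cpp')],
        verbose=False)
```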
@@ -44,18 +45,18 @@ except ImportError:
def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if input.is_cuda:
boundary: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return torch_mutual_information_cuda.mutual_information_cuda(
px, py, boundaries, p)
px, py, boundary, p)
else:
return torch_mutual_information_cpu.mutual_information_cpu(
px, py, boundaries, p)
px, py, boundary, p)
def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, p: torch.Tensor,
boundary: torch.Tensor, p: torch.Tensor,
ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if px.is_cuda:
if torch_mutual_information_cuda is None:
@@ -64,7 +65,7 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
if overwrite_ans_grad:
ans_grad_copy = ans_grad.clone()
ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
px, py, boundary, p, ans_grad_copy, overwrite_ans_grad))
if overwrite_ans_grad:
if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
print(f"Warning: possible excsssive roundoff in mutual information backward "
@@ -72,18 +73,20 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
return ans
else:
return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
px, py, boundaries, p, ans_grad))
px, py, boundary, p, ans_grad))
class MutualInformationRecursionFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: Optional[torch.Tensor]) -> torch.Tensor:
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundary: Optional[torch.Tensor]) -> torch.Tensor:
(B, S, T1) = px.shape
T = T1 - 1
assert py.shape == (B, S + 1, T)
if boundaries is not None:
assert boundaries.shape == (B, 4)
if boundary is not None:
assert boundary.shape == (B, 4)
else:
boundary = torch.zeros(0, 0, dtype=torch.int64, device=px.device)
# p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is the
@@ -101,20 +104,23 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
ans = _mutual_information_forward_dispatcher(px, py, boundary, p)
print(f"p = {p}, boundary = {boundary}")
if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundaries, p)
ctx.save_for_backward(px, py, boundary, p)
return ans
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
(px, py, boundaries, p) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, p, ans_grad)
(px, py, boundary, p) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundary, p, ans_grad)
return (px_grad, py_grad, None)
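A note on the custom autograd Function above: tensors stashed with `save_for_backward()` in `forward()` reappear as `ctx.saved_tensors` in `backward()`, and `backward()` must return one value per `forward()` argument, hence the trailing `None` for `boundary`, which receives no gradient. A small usage sketch under the shape conventions asserted in `forward()` (in practice the `mutual_information_recursion()` wrapper below is the intended entry point, and the native extension must have loaded successfully):

```python
import torch

B, S, T = 2, 3, 4
px = torch.randn(B, S, T + 1, requires_grad=True)
py = torch.randn(B, S + 1, T, requires_grad=True)

# apply() runs forward() and registers backward() with autograd.
ans = MutualInformationRecursionFunction.apply(px, py, None)
ans.sum().backward()

print(px.grad.shape)  # torch.Size([2, 3, 5])
print(py.grad.shape)  # torch.Size([2, 4, 4])
```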
def mutual_information_recursion(input, px, py, boundaries=None):
def mutual_information_recursion(px, py, boundary=None):
"""A recursion that is useful in computing mutual information between two sequences of
real vectors, but may be useful more generally in sequence-to-sequence tasks where
monotonic alignment between pairs of sequences is desired. The definitions of
@@ -154,7 +160,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
is that for optimization purposes we assume the last axis (the t axis)
has stride of 1; this is true if px and py are contiguous.
boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
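To make the boundary format concrete, a hedged example of constructing such a tensor for a batch of two sequences (the begin/end values are invented for illustration):

```python
import torch

S, T = 4, 5
# One row per batch element: [s_begin, t_begin, s_end, t_end].
# Row 0 covers the full ranges [0, 0, S, T]; row 1 restricts both sequences.
boundary = torch.tensor([[0, 0, S, T],
                         [1, 1, 3, 4]], dtype=torch.int64)
```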
@@ -182,8 +188,9 @@ def mutual_information_recursion(input, px, py, boundaries=None):
assert py.shape == (B, S + 1, T)
assert px.dtype == py.dtype
(B, S, T) = px.shape
if boundaries is not None:
assert boundaries.dtype == torch.LongTensor
assert boundaries.shape == (B, 4)
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
return MutualInformationRecursion.apply(px, py, boundaries)
return MutualInformationRecursionFunction.apply(px, py, boundary)
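As a reading aid, the recursion the C++/CUDA kernels implement can be written naively in Python. This is a hedged sketch based only on the shape conventions documented above (`px`: (B, S, T+1), `py`: (B, S+1, T), `p`: (B, S+1, T+1)); the assumed update is p[b,s,t] = logaddexp(p[b,s-1,t] + px[b,s-1,t], p[b,s,t-1] + py[b,s,t-1]) with p[b,0,0] = 0, and the kernels may differ in details such as boundary handling:

```python
import torch

def mutual_information_reference(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    """Naive O(B * S * T) reference for the assumed recursion; returns p[:, S, T]."""
    (B, S, T1) = px.shape
    T = T1 - 1
    neg_inf = torch.full((B,), float('-inf'), dtype=px.dtype)
    p = torch.full((B, S + 1, T + 1), float('-inf'), dtype=px.dtype)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term_x = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term_y = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S, T]

# Under this assumed recursion, px = py = 0 simply counts monotonic paths:
# the result is log(binomial(S + T, S)) for every batch element,
# e.g. S = 4, T = 5 gives log(126) ~= 4.84.
```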
@@ -519,7 +519,7 @@ void mutual_information_backward_kernel(
// comments. We'll focus, in the comments, on differences from the forward
// pass.
const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1,
// num_t_blocks = T / BLOCK_SIZE + 1,
num_blocks_this_iter = min(iter + 1, num_s_blocks);
@@ -668,7 +668,7 @@ void mutual_information_backward_kernel(
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
} else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
static_cast<unsigned int>(block_T)) {
// casting to unsigned before the comparison tests for both negative and
@@ -678,7 +678,7 @@ void mutual_information_backward_kernel(
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
}
// The highest-numbered value in p_buf that we need (corresponding,
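The fix in these kernel hunks adds the missing batch index `b` when reading `p_grad`. A quick way to sanity-check the backward pass after such a change is to compare the extension's gradients against autograd run through a pure-Python reference like the sketch above; a hedged example (it assumes float64 is supported and that `mutual_information_recursion` and `mutual_information_reference` are importable):

```python
import torch

B, S, T = 2, 4, 5
px = torch.randn(B, S, T + 1, dtype=torch.float64, requires_grad=True)
py = torch.randn(B, S + 1, T, dtype=torch.float64, requires_grad=True)

# Gradients via the native extension.
mutual_information_recursion(px, py).sum().backward()
grad_px_ext, grad_py_ext = px.grad.clone(), py.grad.clone()

# Gradients via autograd through the pure-Python reference.
px.grad = None
py.grad = None
mutual_information_reference(px, py).sum().backward()

assert torch.allclose(grad_px_ext, px.grad, atol=1.0e-06)
assert torch.allclose(grad_py_ext, py.grad, atol=1.0e-06)
```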
@@ -3,72 +3,24 @@
import random
import torch
from torch_mutual_information import mutual_information
from torch_mutual_information import mutual_information_recursion
def test_mutual_information_basic():
print("Running test_mutual_information_basic()")
for dtype in [torch.float32, torch.float64]:
B = 2
C = 4
T = 10
x = -2.0 + 0.4 * torch.arange(10, dtype=dtype)
x = x.reshape(1, 1, 10).repeat(B, C, 1)
S = 4
T = 5
px = torch.zeros(B, S, T + 1) # log of an odds ratio
py = torch.zeros(B, S + 1, T) # log of an odds ratio
K = 4
N = K * 2
params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1) - 3
x.requires_grad = True
params.requires_grad = True
print("x = ", x)
print("params = ", params)
print("x.shape = ", x.shape)
m = mutual_information_recursion(px, py)
y = mutual_information(x, params, dim = 1)
print("m = ", m)
if True:
# Check
x2 = x.reshape(B, C, 5, 2)
assert torch.allclose(mutual_information(x, params, dim = 1), mutual_information(x2, params, dim = 1).reshape(x.shape))
x2 = x.reshape(B, 1, C, 10)
assert torch.allclose(mutual_information(x, params, dim = 1), mutual_information(x2, params, dim = 2).reshape(x.shape))
print("y = ", y)
y.sum().backward()
if torch.cuda.is_available():
# test that the CUDA forward is the same as the CPU forward.
device = torch.device('cuda:0')
x2 = x.to(device).detach()
x2.requires_grad = True
params2 = params.to(device).detach()
params2.requires_grad = True
y2 = mutual_information(x2, params2, dim = 1).to(torch.device('cpu'))
print("Checking CUDA is same")
if not torch.allclose(y, y2, atol=1.0e-06):
print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
assert(0);
y2.sum().backward()
if not torch.allclose(x.grad, x2.grad.to('cpu'), atol=1.0e-06):
print(f"Error: CPU x-grad versus CUDA grad not the same: {x.grad} vs. {x2.grad}, diff = {x2.grad.to('cpu')-x.grad}")
assert(0);
if not torch.allclose(params.grad, params2.grad.to('cpu'), atol=1.0e-06):
print(f"Error: CPU params-grad versus CUDA grad not the same: {params.grad} vs. {params2.grad}, diff = {params2.grad.to('cpu')-params.grad}")
assert(0);
print("x.grad = ", x.grad)
print("params.grad = ", params.grad)
# Just eyeballing the above to make sure it looks reasonable.
def test_mutual_information_deriv():
""" Tests derivatives in randomized way """
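The body of `test_mutual_information_deriv()` is collapsed in this diff. For a concrete picture of what a randomized derivative check can look like, here is an illustrative sketch built on `torch.autograd.gradcheck` (not the repository's actual test; it assumes the extension supports float64 inputs):

```python
import random
import torch
from torch_mutual_information import mutual_information_recursion

def test_mutual_information_deriv_sketch():
    for _ in range(5):
        B = random.randint(1, 3)
        S = random.randint(1, 5)
        T = random.randint(1, 6)
        px = torch.randn(B, S, T + 1, dtype=torch.float64, requires_grad=True)
        py = torch.randn(B, S + 1, T, dtype=torch.float64, requires_grad=True)
        # gradcheck compares the custom backward against finite differences.
        torch.autograd.gradcheck(
            lambda a, b: mutual_information_recursion(a, b).sum(),
            (px, py), eps=1.0e-06, atol=1.0e-04)
```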