Commit 1ad556dc authored by Daniel Povey

Fix some bugs..

parent 52ae49ee
from .mutual_information import mutual_information_recursion
import os
import torch
from torch import Tensor
from typing import Tuple, Optional
from torch.utils.cpp_extension import load

VERBOSE = False
@@ -44,18 +45,18 @@ except ImportError:
def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
                                           boundary: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    if px.is_cuda:
        if torch_mutual_information_cuda is None:
            raise EnvironmentError(f'Failed to load native CUDA module')
        return torch_mutual_information_cuda.mutual_information_cuda(
            px, py, boundary, p)
    else:
        return torch_mutual_information_cpu.mutual_information_cpu(
            px, py, boundary, p)
def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
                                            boundary: torch.Tensor, p: torch.Tensor,
                                            ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if px.is_cuda:
        if torch_mutual_information_cuda is None:
@@ -64,7 +65,7 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
        if overwrite_ans_grad:
            ans_grad_copy = ans_grad.clone()
        ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
            px, py, boundary, p, ans_grad_copy, overwrite_ans_grad))
        if overwrite_ans_grad:
            if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
                print(f"Warning: possible excessive roundoff in mutual information backward "
@@ -72,18 +73,20 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
        return ans
    else:
        return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
            px, py, boundary, p, ans_grad))

class MutualInformationRecursionFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundary: Optional[torch.Tensor]) -> torch.Tensor:
        (B, S, T1) = px.shape
        T = T1 - 1
        assert py.shape == (B, S + 1, T)
        if boundary is not None:
            assert boundary.shape == (B, 4)
        else:
            boundary = torch.zeros(0, 0, dtype=torch.int64, device=px.device)

        # p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is the
@@ -101,20 +104,23 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

        ans = _mutual_information_forward_dispatcher(px, py, boundary, p)

        print(f"p = {p}, boundary = {boundary}")

        if px.requires_grad or py.requires_grad:
            ctx.save_for_backward(px, py, boundary, p)
        return ans
    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
        (px, py, boundary, p) = ctx.saved_tensors
        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundary, p, ans_grad)
        return (px_grad, py_grad, None)

def mutual_information_recursion(px, py, boundary=None):
    """A recursion that is useful in computing mutual information between two sequences of
    real vectors, but may be useful more generally in sequence-to-sequence tasks where
    monotonic alignment between pairs of sequences is desired.  The definitions of
@@ -154,7 +160,7 @@ def mutual_information_recursion(px, py, boundary=None):
          is that for optimization purposes we assume the last axis (the t axis)
          has stride of 1; this is true if px and py are contiguous.
      boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
          [s_begin, t_begin, s_end, t_end].  If not supplied, the values
          [0, 0, S, T] will be assumed.  These are the beginning and
          one-past-the-last positions in the x and y sequences
@@ -182,8 +188,9 @@ def mutual_information_recursion(px, py, boundary=None):
    assert py.shape == (B, S + 1, T)
    assert px.dtype == py.dtype
    (B, S, T) = px.shape
    if boundary is not None:
        assert boundary.dtype == torch.int64
        assert boundary.shape == (B, 4)
    return MutualInformationRecursionFunction.apply(px, py, boundary)
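
For reference, here is a minimal usage sketch (not part of this commit) based on the shapes and the boundary format documented in the docstring above; the tensor values and boundary rows are made up purely for illustration:

import torch
from torch_mutual_information import mutual_information_recursion

B, S, T = 2, 4, 5
px = torch.randn(B, S, T + 1, requires_grad=True)   # log of an odds ratio, per (s, t)
py = torch.randn(B, S + 1, T, requires_grad=True)   # log of an odds ratio, per (s, t)

# With boundary=None, the default boundaries [0, 0, S, T] are assumed.
m = mutual_information_recursion(px, py)

# With explicit per-sequence boundaries; each row is [s_begin, t_begin, s_end, t_end].
boundary = torch.tensor([[0, 0, S, T],
                         [0, 0, S - 1, T - 1]], dtype=torch.int64)
m = mutual_information_recursion(px, py, boundary)
m.sum().backward()   # populates px.grad and py.grad via the backward dispatcher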
@@ -519,7 +519,7 @@ void mutual_information_backward_kernel(
  // comments.  We'll focus, in the comments, on differences from the forward
  // pass.
  const int num_s_blocks = S / BLOCK_SIZE + 1,
      // num_t_blocks = T / BLOCK_SIZE + 1,
      num_blocks_this_iter = min(iter + 1, num_s_blocks);
@@ -668,7 +668,7 @@ void mutual_information_backward_kernel(
          s = s_in_block + s_block_begin,
          t = t_in_block + t_block_begin;
      p_buf[s_in_block][t_in_block] = (
          s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
    } else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
               static_cast<unsigned int>(block_T)) {
      // casting to unsigned before the comparison tests for both negative and
@@ -678,7 +678,7 @@ void mutual_information_backward_kernel(
          s = s_in_block + s_block_begin,
          t = t_in_block + t_block_begin;
      p_buf[s_in_block][t_in_block] = (
          s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
    }

    // The highest-numbered value in p_buf that we need (corresponding,
...
@@ -3,72 +3,24 @@
import random
import torch
from torch_mutual_information import mutual_information_recursion

def test_mutual_information_basic():
    print("Running test_mutual_information_basic()")
    for dtype in [torch.float32, torch.float64]:
        B = 2
        S = 4
        T = 5
        px = torch.zeros(B, S, T + 1, dtype=dtype)  # log of an odds ratio
        py = torch.zeros(B, S + 1, T, dtype=dtype)  # log of an odds ratio

        m = mutual_information_recursion(px, py)

        print("m = ", m)

def test_mutual_information_deriv():
    """ Tests derivatives in randomized way """
...