Commit 621d5fbb authored by Daniel Povey

Initial work on interface code

parent 8e89c34b
@@ -43,151 +43,125 @@ except ImportError:
-def _mutual_information_forward_dispatcher(input: torch.Tensor,
-                                           params: torch.Tensor) -> torch.Tensor:
-    if input.is_cuda:
+def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
+                                           boundaries: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+    if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return torch_mutual_information_cuda.mutual_information_cuda(
-            input, params.contiguous())
+            px, py, boundaries, q)
     else:
         return torch_mutual_information_cpu.mutual_information_cpu(
-            input, params)
+            px, py, boundaries, q)


-def _mutual_information_backward_dispatcher(input: torch.Tensor,
-                                            params: torch.Tensor,
-                                            grad_output) -> Tuple[torch.Tensor, torch.Tensor]:
-    if input.is_cuda:
+def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
+                                            boundaries: torch.Tensor, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
-            input, params,
-            grad_output))
+            px, py, boundaries, q))
     else:
         return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
-            input, params, grad_output))
+            px, py, boundaries, q))

-def _reshape_as_3dim(x: torch.Tensor, dim: int):
-    """
-    Returns x reshaped so that dimension 'dim' is the middle of 3 dimensions,
-    combining dimensions and unsqueezing as needed.  For example (writing
-    the behavior of this function as
-        input_shape, dim -> output_shape,
-    it will do:
-        (3,), 0 -> (1, 3, 1)
-        (2, 5, 9), 1 -> (2, 5, 9)
-        (2, 5, 9), 2 -> (10, 9, 1)
-        (3, 4, 5, 6), 2 -> (12, 5, 6)
-    The idea is to normalize the shape so the channel dimension is the middle
-    of 3, so the implementation can deal with a fixed layout.
-    Args:
-        x: tensor to be reshaped
-        dim: Dimension of x that is to be the middle of 3 dimensions in the result.
-            If negative, interpreted as an offset from x.dim.
-    """
-    if dim < 0:
-        dim += x.ndim
-    orig_shape = list(x.shape)
-    # `new_shape` is `orig_shape` but modified so that the channel dim (`dim`)
-    # is dimension/axis 1.  We do this not by transposing, but by combining
-    # adjacent dims.
-    a, b = 1, 1
-    for i in range(0, dim):
-        a *= orig_shape[i]
-    for i in range(dim + 1, len(orig_shape)):
-        b *= orig_shape[i]
-    new_shape = (a, orig_shape[dim], b)
-    return x.reshape(new_shape)  # `reshape` will make a contiguous copy if needed.


-class LearnedNonlinFunction(torch.autograd.Function):
+class MutualInformationRecursionFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input: torch.Tensor, params: torch.Tensor, dim: int) -> torch.Tensor:
-        if dim < 0:
-            dim += input.ndim
-        assert dim >= 0 and dim < input.ndim
-        assert params.ndim == 2 and params.shape[1] % 2 == 1
-        assert params.shape[0] == input.shape[dim]
-
-        ctx.dim = dim
-        ctx.save_for_backward(input, params)
-        output = _mutual_information_forward_dispatcher(_reshape_as_3dim(input, dim),
-                                                        params)
-        return output
+    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
+        (B, S, T) = px.shape
+
+        # q is a rearrangement of a tensor p which is of shape (B,S,T),
+        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
+        # representation is that each row of q depends only on the previous row,
+        # so we can access the rows sequentially and this leads to
+        # better memory access patterns.  We are assuming that most likely
+        # T < S, which means that q should not require much more memory than p.
+        #
+        # Actually we access q beginning from 0 indexes even if `boundaries`
+        # has t_begin > 0 or s_begin > 0, i.e. we really access q as
+        #   q[b, s-s_begin + t-t_begin, t-t_begin];
+        # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
+        # We don't need q if we are not going to do backprop.
+        q = (torch.empty(B, S + T, T, device=px.device, dtype=px.dtype)
+             if px.requires_grad or py.requires_grad
+             else None)
+        ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
+        if px.requires_grad or py.requires_grad:
+            ctx.save_for_backward(px, py, boundaries, q)
+        return ans

     @staticmethod
-    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
-        (input, params) = ctx.saved_tensors
-        orig_shape = input.shape
-        # We re-do the reshaping in the backward, rather than save the reshaped
-        # input, so that if this reshaping results in a copy it is not retained
-        # (this saves memory at the expense of a little extra work in such
-        # situations).
-        grad_input, grad_params = _mutual_information_backward_dispatcher(
-            _reshape_as_3dim(input, ctx.dim), params, grad_output)
-        return grad_input.reshape(input.shape), grad_params, None
+    def backward(ctx, ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
+        (px, py, boundaries, q) = ctx.saved_tensors
+        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, q)
+        return (px_grad, py_grad, None)


-def mutual_information(input, params, dim):
-    """Learned nonlinearity.
-    Args:
-        input: The input, to be transformed pointwise; may be of any shape.
-
-        params: The parameters of the learned nonlinearity.  Interpreted
-          as of shape (C, N + 1), where C is the channel and N, which
-          must be a power of 2 more than 1, is the number of linear regions in the
-          piecewise linear function.  The first element is the log
-          of the distance between the discontinuities, and the
-          remaining elements are the derivatives of the function
-          in the linear pieces.  We can explain what this function
-          is as follows:
-          Let the row of `params` for a particular channel be
-          interpreted as (l, d0, d1, d2 ... ).  Let K = N/2, and L = exp(l).
-          Then the discontinuities in the function are at:
-              L * ( -K+1, -K+2, .., -1, 0, 1, .. K-1 )
-          and the values d0, d1 .. are interpreted as the slopes of the
-          function in the intervals, respectively:
-              [-inf .. L*(-K+1)), [L*(-K+1) .. L*(-K+2)], ...
-          and we use these together with the assumption that the
-          function's value at x=0 is 0, to compute the function's value.
-
-          In terms of concrete calculations, we do it as follows:
-          Firstly, we can get rid of the factor of L by treating the l
-          parameter as a scale on the input and output, i.e.:
-               x = input * exp(-l)
-               ... do the calculation y = f(x) without a scale, interpreting the
-               discontinuities as being at integer values -K+1, -K+2, ... K-1,
-          and then:
-               output = y * exp(l)
-
-          The core computation requires computing the y-values at the
-          discontinuities at -K+1, -K+2 and so on.  Each one equals
-          the sign of the offset (- for negative offsets) times the sum
-          of the derivatives 'd' for the regions between the current
-          point and zero.  If we number these as offsets o_0, o_1 and
-          so on up to N-2, then the formula is:
-
-             for o_n with n < K,   o_n = -sum(k = n+1..K-1) d_k
-             for o_n with n >= K,  o_n = sum(k = K..n-1) d_k
-
-          e.g. if K=3 and (d0, d1, d2, d3, d4, d5) = (1, 2, 1, 2, 1, 1), then:
-
-             o_0 = -(d1+d2)  = -3    # x=-2 maps to y=-3
-             o_1 = -(d2)     = -2    # x=-1 maps to y=-2
-             o_2 = ()        =  0    # x=0 maps to y=0
-             o_3 = (d3)      =  2    # x=1 maps to y=2
-             o_4 = (d3 + d4) =  3    # x=2 maps to y=3
-
-        dim: The dimension of `input` that corresponds to the channel.  It is
-          recommended that the channel should not be the fastest-varying
-          dimension (the one with stride=1), because this will make
-          the data loads and stores be non-coalesced and the kernels
-          will be quite slow.
-    Return: output, of the same shape as `input`.
-    """
-    return LearnedNonlinFunction.apply(input, params, dim)
+def mutual_information_recursion(px, py, boundaries=None):
+    """A recursion that is useful in computing mutual information between two sequences of
+    real vectors, but may be useful more generally in sequence-to-sequence tasks where
+    monotonic alignment between pairs of sequences is desired.  The definitions of
+    the arguments are definitions that would be used when computing this type of
+    mutual information, but you can also view them as arbitrary quantities and just
+    look at the formula computed by this function.
+
+    Args:
+        px: A torch.Tensor of some floating point type, with shape [B][S][T],
+          where B is the batch size, S is the length of the 'x' sequence
+          (including representations of EOS symbols but not BOS symbols), and T is the
+          length of the 'y' sequence (including representations of
+          EOS symbols but not BOS symbols).  In the mutual information application,
+          px[b][s][t] would represent the following log odds ratio; ignoring
+          the b index on the right to make the notation more compact,
+
+            px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
+
+          This expression also implicitly includes the log-probability of
+          choosing to generate an x value as opposed to a y value.  In
+          practice it might be computed as a + b, where a is the log
+          probability of choosing to extend the sequence of length (s,t)
+          with an x as opposed to a y value; and b might in practice be
+          of the form:
+
+            log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
+
+          where N is the number of terms that the sum over t' included, which
+          might include some or all of the other sequences as well as this one.
+        py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
+          representing
+
+            py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
+
+          This function does not treat x and y differently; the only difference
+          is that the implementation assumes for optimization purposes that y
+          is likely to be the shorter sequence, i.e. that "most of the time T < S",
+          and it will be faster if you respect this.
+        boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
+          [s_begin, t_begin, s_end, t_end].  If not supplied, the values
+          [0, 0, S, T] will be assumed.  These are the beginning and
+          one-past-the-last positions in the x and y sequences
+          respectively, and can be used if not all sequences are of the same length.
+
+    Returns:
+        Returns a torch.Tensor of shape [B], containing the log of the mutual
+        information between the b'th pair of sequences.  This is defined by
+        the following recursion on p[b,s,t] (where p is of shape [B,S,T]),
+        representing a mutual information between sub-sequences of lengths s and t:
+
+            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t],
+                               p[b,s,t-1] + py[b,s,t])
+
+        where, in the case boundaries == None, the edge cases are handled
+        by treating p[b,-1,-1] as 0 and all other quantities with negative
+        indexes as -infinity; and ans[b] would equal p[b,S-1,T-1].  The extension to
+        cases where the boundaries are specified should be obvious.
+    """
+    assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
+    (B, S, T) = px.shape
+    if boundaries is not None:
+        assert boundaries.dtype == torch.int64
+        assert boundaries.shape == (B, 4)
+    return MutualInformationRecursionFunction.apply(px, py, boundaries)
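
A minimal usage sketch of the new interface, for reference only (not part of the commit). It assumes the package has been built so that the native torch_mutual_information_cpu/cuda modules load; the import path `torch_mutual_information` is an assumption, not confirmed by this diff.

    import torch
    from torch_mutual_information import mutual_information_recursion  # hypothetical import path

    B, S, T = 2, 10, 5                                 # T < S, the layout the code is optimized for
    px = torch.randn(B, S, T, requires_grad=True)      # log odds ratios for extending with x_s
    py = torch.randn(B, S, T, requires_grad=True)      # log odds ratios for extending with y_t
    # Optional per-sequence limits, rows of [s_begin, t_begin, s_end, t_end]:
    boundaries = torch.tensor([[0, 0, S, T], [0, 0, 8, 4]], dtype=torch.int64)

    ans = mutual_information_recursion(px, py, boundaries)   # shape [B]
    ans.sum().backward()      # px.grad and py.grad are filled via the backward dispatcher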
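The comment in MutualInformationRecursionFunction.forward() describes storing p[b,s,t] as q[b,s+t,t], so that each row of q (fixed s+t) is an anti-diagonal of p and depends only on the previous row. The sketch below just illustrates that layout in plain PyTorch; the helper name and the choice of S+T-1 rows are illustrative assumptions (the code above allocates S+T rows, a safe upper bound).

    import torch

    def rearrange_p_to_q(p: torch.Tensor) -> torch.Tensor:
        # Illustrative only: scatter p[b,s,t] into q[b,s+t,t]; slots of q that
        # correspond to no (s,t) pair are simply left at zero here.
        (B, S, T) = p.shape
        q = torch.zeros(B, S + T - 1, T, dtype=p.dtype)
        for s in range(S):
            for t in range(T):
                q[:, s + t, t] = p[:, s, t]
        return q

    p = torch.randn(2, 5, 3)      # (B, S, T), with T < S as the code expects
    q = rearrange_p_to_q(p)
    assert torch.equal(q[:, 2 + 1, 1], p[:, 2, 1])   # s=2, t=1 lands in row s+t=3
    # A kernel that fills q row by row only ever needs row n-1 to compute row n,
    # which is the memory-access pattern the comment is describing.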