Commit 621d5fbb authored by Daniel Povey's avatar Daniel Povey
Browse files

Initial work on interface code

parent 8e89c34b
......@@ -43,151 +43,125 @@ except ImportError:
def _mutual_information_forward_dispatcher(input: torch.Tensor,
params: torch.Tensor) -> torch.Tensor:
def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
if input.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return torch_mutual_information_cuda.mutual_information_cuda(
input, params.contiguous())
px, py, boundaries, q)
else:
return torch_mutual_information_cpu.mutual_information_cpu(
input, params)
px, py, boundaries, q)
def _mutual_information_backward_dispatcher(input: torch.Tensor,
params: torch.Tensor,
grad_output) -> Tuple[torch.Tensor, torch.Tensor]:
if input.is_cuda:
def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
input, params,
grad_output))
px, py, boundaries, q))
else:
return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
input, params, grad_output))
px, py, boundaries, q))
def _reshape_as_3dim(x: torch.Tensor, dim: int):
"""
Returns x reshaped so that dimension 'dim' is the middle of 3 dimensions,
combining dimensions and unsqueezing as needed. For example (writing
the behavior of this function as
input_shape, dim -> output_shape,
it will do:
(3), 0 -> (1, 3, 1)
(2, 5, 9), 1 -> (2, 5, 9)
(2, 5, 9), 2 -> (10, 9, 1)
(3, 4, 5, 6) -> (12, 5, 6)
The idea is to normalize the shape so the channel dimension is the middle
of 3, so the implementation can deal with a fixed layout.
Args:
x: tensor to be reshaped
dim: Dimension of x that is to be the middle of 3 dimensions in the result.
If negative, interpreted as an offset from x.dim.
"""
if dim < 0:
dim += input.ndim
orig_shape = list(x.shape)
# `new_shape` is `orig_shape` but modified so that the channel dim (`dim`)
# is dimension/axis 1. We do this not by transposing, but by combining
# adjacent dims.
a, b = 1, 1
for i in range(0, dim):
a *= orig_shape[i]
for i in range(dim + 1, len(orig_shape)):
b *= orig_shape[i]
new_shape = (a, orig_shape[dim], b)
return x.reshape(new_shape) # `reshape` will make a contiguous copy if needed.
class LearnedNonlinFunction(torch.autograd.Function):
class MutualInformationRecursionFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, params: torch.Tensor, dim: int) -> torch.Tensor:
if dim < 0:
dim += input.ndim
assert dim >= 0 and dim < input.ndim
assert params.ndim == 2 and params.shape[1] % 2 == 1
assert params.shape[0] == input.shape[dim]
ctx.dim = dim
ctx.save_for_backward(input, params)
output = _mutual_information_forward_dispatcher(_reshape_as_3dim(input, dim),
params)
return output
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
(B, S, T) = px.shape
# q is a rearrangement of a tensor p which is of shape (B,S,T),
# using p[b,s,t] == q[b,s+t,t]. The reason for working with this
# representation is that each row of q depends only on the previous row,
# so we can access the rows sequenctially and this leads to
# better memory access patterns. We are assuming that most likely
# T < S, which means that q should not require much more memory than p.
#
# Actually we access q beginning from 0 indexes even if `boundaries`
# has t_begin > 0 or s_begin > 0, i.e. we really access q as
# q[b, s-s_begin + t-t_begin, t-t_begin];
# note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
# We don't need q if we are not going to do backprop
q = (torch.empty(B, S + T, device=px.device, dtype=px.dtype)
if px.requires_grad or py.requires_grad
else None)
ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundaries, w)
@staticmethod
def backward(ctx, grad_output: torch.Tensor, None) -> Tuple[torch.Tensor, torch.Tensor, None]:
(input, params) = ctx.saved_tensors
orig_shape = input.shape
# We re-do the reshaping in the backward, rather than save the reshaped
# input, so that if this reshaping results in a copy it is not retained
# (this saves memory at the expense of a little extra work in such
# situations).
grad_input, grad_params = _mutual_information_backward_dispatcher(
_reshape_as_3dim(input, ctx.dim), params, grad_output)
return grad_input.reshape(input.shape), grad_params, None
def mutual_information(input, params, dim):
"""Learned nonlinearity.
def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
(px, py, boundaries, q) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, q)
return (px_grad, py_grad, None)
def mutual_information_recursion(input, px, py, boundaries=None):
"""A recursion that is useful in computing mutual information between two sequences of
real vectors, but may be useful more generally in sequence-to-sequence tasks where
monotonic alignment between pairs of sequences is desired. The definitions of
the arguments are definitions that would be used when computing this type of
mutual information, but you can also view them as arbitrary quantities and just
look at the formula computed by this function.
Args:
input: The input, to be transformed pointwise; may be of any shape.
params: The parameters of the learned nonlinearity. Interpreted
as of shape (C, N + 1), where C is the channel and N, which
must be a power of 2 more than 1, is the number of linear regions in the
piecewise linear function. The first element is the log
of the distance between the discontinuities, and the
remaining elements are the derivatives of the function
in the linear pieces. We can explain what this function
is as follows:
Let the row of `params` for a particular channel be
interpreted as (l, d0, d1, d2 ... ). Let K = N/2, and L = exp(l).
Then the discontinuities in the function are at:
L * ( -K+1, -K+2, .., -1, 0, 1, .. K-1 )
and the values d0, d1 .. are interpreted as the slopes of the
function in the intervals, respectively:
[-inf.. L*(-K+1)), [L*-K+1..L*-K+2], ...
and we use these together with the assumption that the
function's value at x=0 is 0, to compute the function's value.
In terms of concrete calculations, we do it as follows:
Firstly, we can get rid of the factor of L by treating the l
parameter as a scale on the input and output, i.e.:
x = input * exp(-l)
... do the calculation y = f(xwithout a scale, interpreting the
discontinuities as being at integer values -K+1, -K+2, ... K+1,
and then:
output = y * = output * exp(l)
The core computation requires computing the y-values at the
discontinuities at -K+1, -K+2 and so on. Each one equals
the sign of the offset (- for negative K) times the sum
of the derivatives 'd' for the regions between the current
points and zero. If we number these as offsets o0, o1 and
so on up to N-2, then the formula is:
for o_n with n < K, o_N = -sum(k = n+1..K-1) d_k
for o_n with n >= k, o_N = sum(K..n-1) d_k
e.g. if K=3 and (d0, d1, d2, d3, d4, d5) = (1, 2, 1, 2, 1, 1), then:
o_0 = -(d1+d2) = -3 # x=-2 maps to y=-3
o_1 = -(d2) = -2 # x=-1 maps to y=-2
o_2 = () = 0 # x=0 maps to y=0
o_3 = (d3) = 2 # x=1 maps to y=2
o_4 = (d3 + d4) = 3 # x=2 maps to y=3
dim: The dimension of `input` that corresponds to the channel. It is
recommended that the channel should not be the fastest-varying
dimension (the one with stride=1), because this will make
the data loads and stores be non-coalesced and the kernels
will be quite slow.
Return: output, of the same shape as `input`.
px: A torch.Tensor of some floating point type, with shape [B][S][T],
where B is the batch size, S is the length of the 'x' sequence
(including representations of EOS symbols but not BOS symbols), and S is the
length of the 'y' sequence (including representations of
EOS symbols but not BOS symbols). In the mutual information application,
px[b][s][t] would represent the following log odds ratio; ignoring
the b index on the right to make the notation more compact,
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an x value as opposed to a y value. In
practice it might be computed as a + b, where a is the log
probability of choosing to extend the sequence of length (s,t)
with an x as opposed to a y value; and b might in practice be
of the form:
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where N is the number of terms that the sum over t' included, which
might include some or all of the other sequences as well as this one.
py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
representing
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat x and y differently; the only difference
is that the implementation assumes for optimization purposes that y
is likely to be the shorter sequence, i.e. that "most of the time T < S",
and it will be faster if you respect this.
boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
Returns:
Returns a torch.Tensor of shape [B], containing the log of the mutuafl
information between the b'th pair of sequences. This is defined by
the following recursion on p[b,s,t] (where p is of shape [B,S,T]),
representing a mutual information between sub-sequences of lengths s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])
where in the case where boundaries==None: the edge cases are handled
by treating p[b,-1,-1] as 0 and all other quantities with negative
indexes as -infinity; and ans[b] would equal p[S-1,T-1]. The extension to
cases where the boundaries are specified should be obvious.
"""
return LearnedNonlinFunction.apply(x, params, dim)
assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
(B, S, T) = px.shape
if boundaries is not None:
assert boundaries.dtype == torch.LongTensor
assert boundaries.shape == (B, 4)
return MutualInformationRecursion.apply(px, py, boundaries)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment