OpenDAS / FAST-RNNT · Commits

Commit 621d5fbb
authored Jul 26, 2021 by Daniel Povey
parent 8e89c34b

    Initial work on interface code

Showing 1 changed file with 102 additions and 128 deletions.

torch_mutual_information/mutual_information.py  (+102, -128)
@@ -43,151 +43,125 @@ except ImportError:
-def _mutual_information_forward_dispatcher(input: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
+def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
+                                           boundaries: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
     if input.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return torch_mutual_information_cuda.mutual_information_cuda(
-            input, params.contiguous())
+            px, py, boundaries, q)
     else:
         return torch_mutual_information_cpu.mutual_information_cpu(
-            input, params)
+            px, py, boundaries, q)


-def _mutual_information_backward_dispatcher(input: torch.Tensor, params: torch.Tensor,
-                                            grad_output) -> Tuple[torch.Tensor, torch.Tensor]:
-    if input.is_cuda:
+def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
+                                            boundaries: torch.Tensor, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    if px.is_cuda:
         if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
         return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
-            input, params, grad_output))
+            px, py, boundaries, q))
     else:
         return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
-            input, params, grad_output))
+            px, py, boundaries, q))
-def _reshape_as_3dim(x: torch.Tensor, dim: int):
-    """
-    Returns x reshaped so that dimension 'dim' is the middle of 3 dimensions,
-    combining dimensions and unsqueezing as needed.  For example (writing
-    the behavior of this function as
-        input_shape, dim -> output_shape,
-    it will do:
-        (3,), 0         -> (1, 3, 1)
-        (2, 5, 9), 1    -> (2, 5, 9)
-        (2, 5, 9), 2    -> (10, 9, 1)
-        (3, 4, 5, 6), 2 -> (12, 5, 6)
-    The idea is to normalize the shape so the channel dimension is the middle
-    of 3, so the implementation can deal with a fixed layout.
-
-    Args:
-        x: tensor to be reshaped
-        dim: Dimension of x that is to be the middle of 3 dimensions in the result.
-             If negative, interpreted as an offset from x.dim.
-    """
-    if dim < 0:
-        dim += input.ndim
-    orig_shape = list(x.shape)
-    # `new_shape` is `orig_shape` but modified so that the channel dim (`dim`)
-    # is dimension/axis 1.  We do this not by transposing, but by combining
-    # adjacent dims.
-    a, b = 1, 1
-    for i in range(0, dim):
-        a *= orig_shape[i]
-    for i in range(dim + 1, len(orig_shape)):
-        b *= orig_shape[i]
-    new_shape = (a, orig_shape[dim], b)
-    return x.reshape(new_shape)  # `reshape` will make a contiguous copy if needed.
-class LearnedNonlinFunction(torch.autograd.Function):
+class MutualInformationRecursionFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input: torch.Tensor, params: torch.Tensor, dim: int) -> torch.Tensor:
-        if dim < 0:
-            dim += input.ndim
-        assert dim >= 0 and dim < input.ndim
-        assert params.ndim == 2 and params.shape[1] % 2 == 1
-        assert params.shape[0] == input.shape[dim]
-        ctx.dim = dim
-        ctx.save_for_backward(input, params)
-        output = _mutual_information_forward_dispatcher(
-            _reshape_as_3dim(input, dim), params)
-        return output
+    def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
+        (B, S, T) = px.shape
+        # q is a rearrangement of a tensor p which is of shape (B,S,T),
+        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
+        # representation is that each row of q depends only on the previous row,
+        # so we can access the rows sequentially and this leads to
+        # better memory access patterns.  We are assuming that most likely
+        # T < S, which means that q should not require much more memory than p.
+        #
+        # Actually we access q beginning from 0 indexes even if `boundaries`
+        # has t_begin > 0 or s_begin > 0, i.e. we really access q as
+        #   q[b, s-s_begin + t-t_begin, t-t_begin];
+        # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
+        # We don't need q if we are not going to do backprop.
+        q = (torch.empty(B, S + T, device=px.device, dtype=px.dtype)
+             if px.requires_grad or py.requires_grad else None)
+        ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)
+        if px.requires_grad or py.requires_grad:
+            ctx.save_for_backward(px, py, boundaries, w)

     @staticmethod
-    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
-        (input, params) = ctx.saved_tensors
-        orig_shape = input.shape
-        # We re-do the reshaping in the backward, rather than save the reshaped
-        # input, so that if this reshaping results in a copy it is not retained
-        # (this saves memory at the expense of a little extra work in such
-        # situations).
-        grad_input, grad_params = _mutual_information_backward_dispatcher(
-            _reshape_as_3dim(input, ctx.dim), params, grad_output)
-        return grad_input.reshape(input.shape), grad_params, None
-
-
-def mutual_information(input, params, dim):
-    """Learned nonlinearity.
+    def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
+        (px, py, boundaries, q) = ctx.saved_tensors
+        (px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, q)
+        return (px_grad, py_grad, None)
+
+
+def mutual_information_recursion(input, px, py, boundaries=None):
+    """A recursion that is useful in computing mutual information between two sequences of
+    real vectors, but may be useful more generally in sequence-to-sequence tasks where
+    monotonic alignment between pairs of sequences is desired.  The definitions of
+    the arguments are definitions that would be used when computing this type of
+    mutual information, but you can also view them as arbitrary quantities and just
+    look at the formula computed by this function.
+
     Args:
-      input: The input, to be transformed pointwise; may be of any shape.
-      params: The parameters of the learned nonlinearity.  Interpreted
-           as of shape (C, N + 1), where C is the channel and N, which
-           must be a power of 2 more than 1, is the number of linear regions in the
-           piecewise linear function.  The first element is the log
-           of the distance between the discontinuities, and the
-           remaining elements are the derivatives of the function
-           in the linear pieces.  We can explain what this function
-           is as follows:
-           Let the row of `params` for a particular channel be
-           interpreted as (l, d0, d1, d2 ... ).  Let K = N/2, and L = exp(l).
-           Then the discontinuities in the function are at:
-               L * ( -K+1, -K+2, .., -1, 0, 1, .. K-1 )
-           and the values d0, d1 .. are interpreted as the slopes of the
-           function in the intervals, respectively:
-               [-inf .. L*(-K+1)), [L*(-K+1) .. L*(-K+2)], ...
-           and we use these together with the assumption that the
-           function's value at x=0 is 0, to compute the function's value.
-           In terms of concrete calculations, we do it as follows:
-           Firstly, we can get rid of the factor of L by treating the l
-           parameter as a scale on the input and output, i.e.:
-               x = input * exp(-l)
-           ... do the calculation y = f(x) without a scale, interpreting the
-           discontinuities as being at integer values -K+1, -K+2, ... K-1,
-           and then:
-               output = y * exp(l)
-           The core computation requires computing the y-values at the
-           discontinuities at -K+1, -K+2 and so on.  Each one equals
-           the sign of the offset (- for negative offsets) times the sum
-           of the derivatives 'd' for the regions between the current
-           points and zero.  If we number these as offsets o0, o1 and
-           so on up to N-2, then the formula is:
-             for o_n with n < K,   o_n = -sum(k = n+1..K-1) d_k
-             for o_n with n >= K,  o_n = sum(k = K..n) d_k
-           e.g. if K=3 and (d0, d1, d2, d3, d4, d5) = (1, 2, 1, 2, 1, 1), then:
-             o_0 = -(d1+d2) = -3   # x=-2 maps to y=-3
-             o_1 = -(d2)    = -2   # x=-1 maps to y=-2
-             o_2 = ()       =  0   # x=0 maps to y=0
-             o_3 = (d3)     =  2   # x=1 maps to y=2
-             o_4 = (d3+d4)  =  3   # x=2 maps to y=3
-      dim: The dimension of `input` that corresponds to the channel.  It is
-           recommended that the channel should not be the fastest-varying
-           dimension (the one with stride=1), because this will make
-           the data loads and stores be non-coalesced and the kernels
-           will be quite slow.
-
-    Return: output, of the same shape as `input`.
+      px:  A torch.Tensor of some floating point type, with shape [B][S][T],
+           where B is the batch size, S is the length of the 'x' sequence
+           (including representations of EOS symbols but not BOS symbols), and T is the
+           length of the 'y' sequence (including representations of
+           EOS symbols but not BOS symbols).  In the mutual information application,
+           px[b][s][t] would represent the following log odds ratio; ignoring
+           the b index on the right to make the notation more compact,
+             px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
+           This expression also implicitly includes the log-probability of
+           choosing to generate an x value as opposed to a y value.  In
+           practice it might be computed as a + b, where a is the log
+           probability of choosing to extend the sequence of length (s,t)
+           with an x as opposed to a y value; and b might in practice be
+           of the form:
+             log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
+           where N is the number of terms that the sum over t' included, which
+           might include some or all of the other sequences as well as this one.
+      py:  A torch.Tensor of the same dtype as px, with shape [B][S][T],
+           representing
+             py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
+           This function does not treat x and y differently; the only difference
+           is that the implementation assumes for optimization purposes that y
+           is likely to be the shorter sequence, i.e. that "most of the time T < S",
+           and it will be faster if you respect this.
+      boundaries:  If supplied, a torch.LongTensor of shape [B][4], where each row contains
+           [s_begin, t_begin, s_end, t_end].  If not supplied, the values
+           [0, 0, S, T] will be assumed.  These are the beginning and
+           one-past-the-last positions in the x and y sequences
+           respectively, and can be used if not all sequences are of the same length.
+
+    Returns:
+        Returns a torch.Tensor of shape [B], containing the log of the mutual
+        information between the b'th pair of sequences.  This is defined by
+        the following recursion on p[b,s,t] (where p is of shape [B,S,T]),
+        representing a mutual information between sub-sequences of lengths s and t:
+
+            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])
+
+        where in the case where boundaries==None: the edge cases are handled
+        by treating p[b,-1,-1] as 0 and all other quantities with negative
+        indexes as -infinity; and ans[b] would equal p[S-1,T-1].  The extension to
+        cases where the boundaries are specified should be obvious.
     """
-    return LearnedNonlinFunction.apply(x, params, dim)
+    assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
+    (B, S, T) = px.shape
+    if boundaries is not None:
+        assert boundaries.dtype == torch.LongTensor
+        assert boundaries.shape == (B, 4)
+    return MutualInformationRecursion.apply(px, py, boundaries)
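
For reference, the recursion described in the Returns section above can be sanity-checked with a direct, slow implementation. The sketch below is not part of this commit or repository: the name `naive_mutual_information` is invented, `boundaries` handling is omitted (the full [0, 0, S, T] range is assumed), and the base case p[b,0,0] = 0 is one reading of the edge-case convention quoted above.

import torch


def naive_mutual_information(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # Direct evaluation of
    #   p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])
    # with out-of-range terms treated as -infinity and p[b,0,0] = 0 assumed
    # as the base case; returns p[:, S-1, T-1], one value per batch element.
    (B, S, T) = px.shape
    neg_inf = float("-inf")
    p = torch.full((B, S, T), neg_inf, dtype=px.dtype)
    p[:, 0, 0] = 0.0
    for b in range(B):
        for s in range(S):
            for t in range(T):
                if s == 0 and t == 0:
                    continue
                term_x = p[b, s - 1, t] + px[b, s, t] if s > 0 else px.new_tensor(neg_inf)
                term_y = p[b, s, t - 1] + py[b, s, t] if t > 0 else px.new_tensor(neg_inf)
                p[b, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S - 1, T - 1]


# Example: two batch elements with random "log odds" values.
px = torch.randn(2, 5, 3)
py = torch.randn(2, 5, 3)
print(naive_mutual_information(px, py))  # tensor of shape [2]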
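
A small made-up example of the `boundaries` layout the docstring describes, with one row of [s_begin, t_begin, s_end, t_end] per batch element (values are illustrative only):

import torch

# Batch of 3 sequence pairs: the first two use their full lengths (S=5, T=3),
# the third only a prefix (s in [0, 4), t in [0, 2)).
boundaries = torch.tensor([[0, 0, 5, 3],
                           [0, 0, 5, 3],
                           [0, 0, 4, 2]], dtype=torch.long)
print(boundaries.shape)  # torch.Size([3, 4])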
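
Finally, the forward method's comment describes packing p[b,s,t] along anti-diagonals, q[b, s+t, t] == p[b, s, t], so that each row of q depends only on the previous one. A small illustrative sketch of that index mapping (assumed names and sizes, not code from this repository):

import torch

# Pack a (B, S, T) tensor p into q so that q[b, s + t, t] == p[b, s, t].
# Row k of q then holds the whole anti-diagonal s + t == k of p, which is why
# a kernel can fill q one row at a time, each row depending only on the row above.
B, S, T = 1, 4, 3
p = torch.arange(B * S * T, dtype=torch.float32).reshape(B, S, T)
q = torch.full((B, S + T, T), float("nan"))  # one spare row; S + T - 1 would suffice
for s in range(S):
    for t in range(T):
        q[:, s + t, t] = p[:, s, t]
print(q[0])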