Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
621d5fbb
Commit
621d5fbb
authored
Jul 26, 2021
by
Daniel Povey
Browse files
Initial work on interface code
parent
8e89c34b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
102 additions
and
128 deletions
+102
-128
torch_mutual_information/mutual_information.py
torch_mutual_information/mutual_information.py
+102
-128
No files found.
torch_mutual_information/mutual_information.py
View file @
621d5fbb
...
@@ -43,151 +43,125 @@ except ImportError:
...
@@ -43,151 +43,125 @@ except ImportError:
def _mutual_information_forward_dispatcher(px: torch.Tensor,
                                           py: torch.Tensor,
                                           boundaries: torch.Tensor,
                                           q: torch.Tensor) -> torch.Tensor:
    """Dispatch the mutual-information forward computation to the native
    CUDA or CPU extension, depending on the device of ``px``.

    Args:
      px: log-odds tensor for the 'x' sequence (see
        ``mutual_information_recursion`` for the intended semantics).
      py: log-odds tensor for the 'y' sequence; same dtype/shape family
        as ``px``.
      boundaries: per-sequence ``[s_begin, t_begin, s_end, t_end]`` rows
        (whatever placeholder the native kernels accept when absent).
      q: workspace tensor written by the kernel for later backprop; may
        be ``None`` when no gradient is required.

    Returns:
      The tensor produced by the native kernel.

    Raises:
      EnvironmentError: if ``px`` is a CUDA tensor but the native CUDA
        module failed to load at import time.
    """
    if px.is_cuda:
        if torch_mutual_information_cuda is None:
            # Fixed: dropped a needless f-string prefix (no placeholders).
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_mutual_information_cuda.mutual_information_cuda(
            px, py, boundaries, q)
    else:
        return torch_mutual_information_cpu.mutual_information_cpu(
            px, py, boundaries, q)
def _mutual_information_backward_dispatcher(
        px: torch.Tensor,
        py: torch.Tensor,
        boundaries: torch.Tensor,
        q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch the mutual-information backward computation to the native
    CUDA or CPU extension, depending on the device of ``px``.

    Args:
      px, py: the tensors that were given to the forward computation.
      boundaries: per-sequence ``[s_begin, t_begin, s_end, t_end]`` rows.
      q: the workspace tensor filled in by the forward kernel.

    Returns:
      A tuple ``(px_grad, py_grad)`` of gradients w.r.t. ``px`` and ``py``.

    Raises:
      EnvironmentError: if ``px`` is a CUDA tensor but the native CUDA
        module failed to load at import time.

    NOTE(review): the gradient of the output (``ans_grad``) is not passed
    to the native kernels here; presumably it will be needed to scale the
    returned gradients -- confirm once the kernels are complete.
    """
    if px.is_cuda:
        if torch_mutual_information_cuda is None:
            # Fixed: dropped a needless f-string prefix (no placeholders).
            raise EnvironmentError('Failed to load native CUDA module')
        return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
            px, py, boundaries, q))
    else:
        return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
            px, py, boundaries, q))
class MutualInformationRecursionFunction(torch.autograd.Function):
    """Autograd wrapper for the native mutual-information recursion kernels;
    use via ``mutual_information_recursion`` rather than directly."""

    @staticmethod
    def forward(ctx, px: torch.Tensor, py: torch.Tensor,
                boundaries: torch.Tensor) -> torch.Tensor:
        (B, S, T) = px.shape

        # q is a rearrangement of a tensor p which is of shape (B,S,T),
        # using p[b,s,t] == q[b,s+t,t].  The reason for working with this
        # representation is that each row of q depends only on the previous
        # row, so we can access the rows sequentially and this leads to
        # better memory access patterns.  We are assuming that most likely
        # T < S, which means that q should not require much more memory
        # than p.
        #
        # Actually we access q beginning from 0 indexes even if `boundaries`
        # has t_begin > 0 or s_begin > 0, i.e. we really access q as
        #   q[b, s-s_begin + t-t_begin, t-t_begin];
        # note, rows of `boundaries` are [s_begin, t_begin, s_end, t_end].
        #
        # We don't need q if we are not going to do backprop.
        #
        # NOTE(review): given p[b,s,t] == q[b,s+t,t], q arguably needs a
        # trailing T dimension, i.e. shape (B, S+T, T) rather than
        # (B, S+T) -- confirm against the native kernels.
        q = (torch.empty(B, S + T, device=px.device, dtype=px.dtype)
             if px.requires_grad or py.requires_grad else None)

        ans = _mutual_information_forward_dispatcher(px, py, boundaries, q)

        if px.requires_grad or py.requires_grad:
            # Fixed: the original saved an undefined name `w` here; the
            # workspace tensor computed above is `q`.
            ctx.save_for_backward(px, py, boundaries, q)
        # Fixed: the original fell off the end of forward() and implicitly
        # returned None instead of the answer.
        return ans

    @staticmethod
    def backward(ctx, ans_grad: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        # Fixed: annotated ans_grad as torch.Tensor; bare `Tensor` was not
        # an in-scope name.
        (px, py, boundaries, q) = ctx.saved_tensors
        # NOTE(review): ans_grad is currently unused -- the backward
        # dispatcher does not accept it.  Presumably the finished kernels
        # will need it to scale the gradients; confirm.
        (px_grad, py_grad) = _mutual_information_backward_dispatcher(
            px, py, boundaries, q)
        # The trailing None is the (non-)gradient for `boundaries`.
        return (px_grad, py_grad, None)
def mutual_information_recursion(px, py, boundaries=None):
    """A recursion that is useful in computing mutual information between two
    sequences of real vectors, but may be useful more generally in
    sequence-to-sequence tasks where monotonic alignment between pairs of
    sequences is desired.  The definitions of the arguments are definitions
    that would be used when computing this type of mutual information, but
    you can also view them as arbitrary quantities and just look at the
    formula computed by this function.

    Args:
      px: A torch.Tensor of some floating point type, with shape [B][S][T],
        where B is the batch size, S is the length of the 'x' sequence
        (including representations of EOS symbols but not BOS symbols) and
        T is the length of the 'y' sequence (likewise).  In the mutual
        information application, ignoring the b index,
          px[s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
        which also implicitly includes the log-probability of choosing to
        generate an x value as opposed to a y value.
      py: A torch.Tensor of the same dtype as px, with shape [B][S][T],
        representing
          py[s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
        This function does not treat x and y differently; the only
        difference is that the implementation assumes for optimization
        purposes that y is likely the shorter sequence ("most of the time
        T < S") and will be faster if you respect this.
      boundaries: If supplied, a torch.LongTensor of shape [B][4] whose
        rows contain [s_begin, t_begin, s_end, t_end] (begin and
        one-past-the-last positions in the x and y sequences); if not
        supplied, [0, 0, S, T] is assumed.  Useful when sequences differ
        in length.

    Returns:
      A torch.Tensor of shape [B] containing the log of the mutual
      information between the b'th pair of sequences, defined by the
      recursion on p[b,s,t] (p of shape [B,S,T]):
        p[b,s,t] = log_add(p[b,s-1,t] + px[b,s,t], p[b,s,t-1] + py[b,s,t])
      where, when boundaries is None, edge cases are handled by treating
      p[b,-1,-1] as 0 and all other quantities with negative indexes as
      -infinity, and ans[b] equals p[S-1,T-1].  The extension to explicit
      boundaries is the obvious one.
    """
    assert px.ndim == 3 and px.shape == py.shape and px.dtype == py.dtype
    (B, S, T) = px.shape
    if boundaries is not None:
        # Fixed: compare against the dtype torch.int64; the original
        # compared a dtype to the legacy tensor type torch.LongTensor,
        # which can never be equal.
        assert boundaries.dtype == torch.int64
        assert boundaries.shape == (B, 4)
    # Fixed: the original referenced the nonexistent name
    # `MutualInformationRecursion`; the autograd Function defined above is
    # `MutualInformationRecursionFunction`.
    return MutualInformationRecursionFunction.apply(px, py, boundaries)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment