OpenDAS / FAST-RNNT
Commit 3fde3a89, authored Nov 24, 2021 by Daniel Povey
Add joint version of MI recursion
parent 970fac7c
Showing 3 changed files with 145 additions and 6 deletions
torch_mutual_information/__init__.py                  +1    -1
torch_mutual_information/mutual_information.py        +112  -3
torch_mutual_information/mutual_information_test.py   +32   -2
torch_mutual_information/__init__.py

-from .mutual_information import mutual_information_recursion
+from .mutual_information import mutual_information_recursion, joint_mutual_information_recursion
torch_mutual_information/mutual_information.py
@@ -2,7 +2,7 @@ import os
 import torch
 from torch import Tensor
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Sequence
 from torch.utils.cpp_extension import load

 VERBOSE = False
@@ -168,7 +168,7 @@ def mutual_information_recursion(px, py, boundary=None):
       respectively, and can be used if not all sequences are of the same length.
     Returns:
-      Returns a torch.Tensor of shape [B], containing the log of the mutuafl
+      Returns a torch.Tensor of shape [B], containing the log of the mutual
       information between the b'th pair of sequences.  This is defined by
       the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
       representing a mutual information between sub-sequences of lengths s and t:
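For reference, the recursion described here (and written out in the docstring of joint_mutual_information_recursion below) can be sketched directly in PyTorch for small inputs. The helper name naive_mutual_information is hypothetical, not part of this commit; it assumes boundary is None, so the region defaults to (0, 0, S, T), and it is a readable reference only, not a substitute for the C++/CUDA dispatcher.

import torch

def naive_mutual_information(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # Unoptimized reference for the recursion with boundary = (0, 0, S, T).
    # px: (B, S, T+1), py: (B, S+1, T); returns a tensor of shape (B,).
    (B, S, T1) = px.shape
    T = T1 - 1
    neg_inf = torch.full((B,), float('-inf'), dtype=px.dtype)
    p = torch.full((B, S + 1, T + 1), float('-inf'), dtype=px.dtype)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term_x = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term_y = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S, T]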
@@ -198,5 +198,114 @@ def mutual_information_recursion(px, py, boundary=None):
     # The following assertions are for efficiency
     assert px.stride()[-1] == 1
     assert py.stride()[-1] == 1

     return MutualInformationRecursionFunction.apply(px, py, boundary)
+
+
+def _inner(a: Tensor, b: Tensor) -> Tensor:
+    """
+    Does inner product on the last dimension, with expected broadcasting,
+    i.e. equivalent to (a * b).sum(dim=-1), without creating a large temporary.
+    """
+    assert a.shape[-1] == b.shape[-1]  # let the last dim be K
+    a = a.unsqueeze(-2)  # (..., 1, K)
+    b = b.unsqueeze(-1)  # (..., K, 1)
+    c = torch.matmul(a, b)  # (..., 1, 1)
+    return c.squeeze(-1).squeeze(-1)
+
+
+def joint_mutual_information_recursion(px: Sequence[Tensor],
+                                       py: Sequence[Tensor],
+                                       boundary: Optional[Tensor] = None) -> Sequence[Tensor]:
+    """A recursion that is useful for modifications of RNN-T and similar loss functions,
+    where the recursion probabilities have a number of terms and you want them reported
+    separately.  See mutual_information_recursion() for more documentation of the
+    basic aspects of this.
+
+    Args:
+      px: a sequence of Tensors, each of the same shape [B][S][T+1]
+      py: a sequence of Tensors, each of the same shape [B][S+1][T]; the sequence
+          must be the same length as px.
+      boundary: optionally, a LongTensor of shape [B][4] containing rows
+          [s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S and
+          0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].  These are the
+          beginning and one-past-the-last positions in the x and y sequences
+          respectively, and can be used if not all sequences are of the same length.
+    Returns:
+      a Tensor of shape (len(px), B), whose sum over dim 0 is the total log-prob
+      of the recursion mentioned below, per sequence.  The first element of the
+      sequence of length len(px) is "special", in that it has an offset term
+      reflecting the difference between sum-of-log and log-of-sum; for more
+      interpretable loss values, the "main" part of your loss function should be first.
+
+      The recursion below applies if boundary == None, when it defaults to
+      (0, 0, S, T); here px_sum and py_sum are the sums of the elements of px and py:
+
+        p = tensor of shape (B, S+1, T+1), containing -infinity
+        p[b,0,0] = 0.0
+        # do the following in a loop over s and t:
+        p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
+                           p[b,s,t-1] + py_sum[b,s,t-1])
+                   (if s > 0 or t > 0)
+        return p[:][S][T]
+
+      This function lets you implement the above recursion efficiently, except
+      that it gives you a breakdown of the contribution from all the elements of
+      px and py separately.  As noted above, the first element of the returned
+      sequence is "special".
+    """
+    N = len(px)
+    assert len(py) == N and N > 0
+    B, S, T1 = px[0].shape
+    T = T1 - 1
+    assert py[0].shape == (B, S + 1, T)
+    assert px[0].dtype == py[0].dtype
+
+    px_cat = torch.stack(px, dim=0)  # (N, B, S, T+1)
+    py_cat = torch.stack(py, dim=0)  # (N, B, S+1, T)
+    px_tot = px_cat.sum(dim=0)  # (B, S, T+1)
+    py_tot = py_cat.sum(dim=0)  # (B, S+1, T)
+
+    if boundary is not None:
+        assert boundary.dtype == torch.int64
+        assert boundary.shape == (B, 4)
+        for [s_begin, t_begin, s_end, t_end] in boundary.to('cpu').tolist():
+            assert 0 <= s_begin <= s_end <= S
+            assert 0 <= t_begin <= t_end <= T
+    else:
+        boundary = torch.zeros(0, 0, dtype=torch.int64, device=px_tot.device)
+
+    px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
+    # The following assertions are for efficiency
+    assert px_tot.stride()[-1] == 1 and px_tot.ndim == 3
+    assert py_tot.stride()[-1] == 1 and py_tot.ndim == 3
+
+    p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)
+
+    # note, tot_probs is without grad.
+    tot_probs = _mutual_information_forward_dispatcher(px_tot, py_tot, boundary, p)
+
+    # this is a kind of "fake gradient" that we use, in effect to compute
+    # occupation probabilities.  The backprop will work regardless of the
+    # actual derivative w.r.t. the total probs.
+    ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
+
+    (px_grad, py_grad) = _mutual_information_backward_dispatcher(px_tot, py_tot,
+                                                                 boundary, p, ans_grad)
+
+    px_grad, py_grad = px_grad.reshape(1, B, -1), py_grad.reshape(1, B, -1)
+    px_cat, py_cat = px_cat.reshape(N, B, -1), py_cat.reshape(N, B, -1)
+    x_prods = _inner(px_grad, px_cat)  # (N, B)
+    y_prods = _inner(py_grad, py_cat)  # (N, B)
+
+    # If all the occupation counts were exactly 1.0 (i.e. no partial counts),
+    # "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
+    # will be more positive due to the difference between log-of-sum and
+    # sum-of-log.
+    prods = x_prods + y_prods  # (N, B)
+    with torch.no_grad():
+        offset = tot_probs - prods.sum(dim=0)  # (B,)
+    prods[0] += offset
+    return prods  # (N, B)
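A minimal usage sketch of the new joint_mutual_information_recursion, assuming the cpp_extension builds on the local machine; the shapes and random inputs are illustrative only. The key property, which the updated test below also checks, is that summing the per-term outputs over dim 0 reproduces the single-tensor recursion, with the log-of-sum vs. sum-of-log offset folded into the first term.

import torch
from torch_mutual_information import (mutual_information_recursion,
                                      joint_mutual_information_recursion)

B, S, T = 2, 5, 7
px = torch.randn(B, S, T + 1)
py = torch.randn(B, S + 1, T)

# The recursion only ever sees the sums of the terms, but the joint version
# reports each term's contribution separately, one row per term.
m = mutual_information_recursion(px, py)                            # (B,)
parts = joint_mutual_information_recursion((px * 0.5, px * 0.5),
                                            (py * 0.5, py * 0.5))   # (2, B)
assert torch.allclose(m, parts.sum(dim=0), atol=1e-5)

# parts[0] absorbs the offset, so put the "main" term of a composite loss
# first if you want the per-term values to be interpretable.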
torch_mutual_information/mutual_information_test.py
@@ -3,7 +3,7 @@
 import random
 import torch
-from torch_mutual_information import mutual_information_recursion
+from torch_mutual_information import mutual_information_recursion, joint_mutual_information_recursion


 def test_mutual_information_basic():
@@ -73,9 +73,36 @@ def test_mutual_information_basic():
             #m = mutual_information_recursion(px, py, None)
             m = mutual_information_recursion(px, py, boundary)
+            m2 = joint_mutual_information_recursion((px,), (py,), boundary)
+
+            m3 = joint_mutual_information_recursion((px * 0.5, px * 0.5),
+                                                    (py * 0.5, py * 0.5), boundary)
+            print("m3, before sum, = ", m3)
+            m3 = m3.sum(dim=0)  # it is supposed to be identical only after
+                                # summing over dim 0, corresponding to the
+                                # sequence dim
+
+            print("m = ", m, ", size = ", m.shape)
+            print("m2 = ", m2, ", size = ", m2.shape)
+            print("m3 = ", m3, ", size = ", m3.shape)
+            assert torch.allclose(m, m2)
+            assert torch.allclose(m, m3)

             #print("exp(m) = ", m.exp())
-            (m.sum() * 3).backward()
+            # the loop this is in checks that the CPU and CUDA versions give the same
+            # derivative; by randomizing which of m, m2 or m3 we backprop, we also
+            # ensure that the joint version of the code gives the same derivative
+            # as the regular version.
+            scale = 3
+            if random.random() < 0.5:
+                (m.sum() * scale).backward()
+            elif random.random() < 0.5:
+                (m2.sum() * scale).backward()
+            else:
+                (m3.sum() * scale).backward()
+
             #print("px_grad = ", px.grad)
             #print("py_grad = ", py.grad)
             px_grads.append(px.grad.to('cpu'))
@@ -92,6 +119,9 @@ def test_mutual_information_basic():
     assert 0


 def test_mutual_information_deriv():
     print("Running test_mutual_information_deriv()")
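The test above randomizes which of m, m2, or m3 it backprops through and relies on the surrounding loop to compare derivatives across runs. As a more direct check, illustrative only and not part of this commit, the gradients from the joint version with a single term should agree with those of the regular recursion; shapes are arbitrary and the sketch assumes the extension builds locally.

import torch
from torch_mutual_information import (mutual_information_recursion,
                                      joint_mutual_information_recursion)

B, S, T = 2, 4, 6
px = torch.randn(B, S, T + 1, requires_grad=True)
py = torch.randn(B, S + 1, T, requires_grad=True)

# Backprop through the regular recursion and remember the gradients.
mutual_information_recursion(px, py).sum().backward()
grad_px, grad_py = px.grad.clone(), py.grad.clone()
px.grad, py.grad = None, None

# Backprop through the joint version with a single term; the "occupation
# probability" gradients should agree with the regular version.
joint_mutual_information_recursion((px,), (py,)).sum().backward()
assert torch.allclose(grad_px, px.grad, atol=1e-5)
assert torch.allclose(grad_py, py.grad, atol=1e-5)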