"examples/vscode:/vscode.git/clone" did not exist on "972a9f1323812811cf2478155a861d6aaade036b"
Commit 3fde3a89 authored by Daniel Povey

Add joint version of MI recursion

parent 970fac7c
-from .mutual_information import mutual_information_recursion
+from .mutual_information import mutual_information_recursion, joint_mutual_information_recursion
@@ -2,7 +2,7 @@ import os
import torch
from torch import Tensor
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Sequence
from torch.utils.cpp_extension import load
VERBOSE = False
@@ -168,7 +168,7 @@ def mutual_information_recursion(px, py, boundary=None):
          respectively, and can be used if not all sequences are of the same length.
    Returns:
-      Returns a torch.Tensor of shape [B], containing the log of the mutuafl
+      Returns a torch.Tensor of shape [B], containing the log of the mutual
      information between the b'th pair of sequences.  This is defined by
      the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
      representing a mutual information between sub-sequences of lengths s and t:
@@ -198,5 +198,114 @@ def mutual_information_recursion(px, py, boundary=None):
    # The following assertions are for efficiency
    assert px.stride()[-1] == 1
    assert py.stride()[-1] == 1
    return MutualInformationRecursionFunction.apply(px, py, boundary)
def _inner(a: Tensor, b: Tensor) -> Tensor:
    """
    Does an inner product on the last dimension, with expected broadcasting,
    i.e. it is equivalent to
         (a * b).sum(dim=-1)
    but without creating a large temporary.
    """
    assert a.shape[-1] == b.shape[-1]  # the last dims must match; call this K
    a = a.unsqueeze(-2)  # (..., 1, K)
    b = b.unsqueeze(-1)  # (..., K, 1)
    c = torch.matmul(a, b)  # (..., 1, 1)
    return c.squeeze(-1).squeeze(-1)
def joint_mutual_information_recursion(px: Sequence[Tensor],
                                       py: Sequence[Tensor],
                                       boundary: Optional[Tensor] = None) -> Tensor:
    """A recursion that is useful for modifications of RNN-T and similar loss
    functions, where the recursion probabilities have a number of terms and
    you want them reported separately.  See mutual_information_recursion()
    for more documentation of the basic aspects of this.

    Args:
      px: a sequence of Tensors, each of the same shape [B][S][T+1]
      py: a sequence of Tensors, each of the same shape [B][S+1][T];
          the sequence must be the same length as px.
      boundary: optionally, a LongTensor of shape [B][4] containing rows
          [s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S
          and 0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
          These are the beginning and one-past-the-last positions in the
          x and y sequences respectively, and can be used if not all
          sequences are of the same length.

    Returns:
      a Tensor of shape (len(px), B), whose sum over dim 0 is the total
      log-prob of the recursion mentioned below, per sequence.  The first
      element of the sequence of length len(px) is "special", in that it has
      an offset term reflecting the difference between sum-of-log and
      log-of-sum; for more interpretable loss values, the "main" part of
      your loss function should be first.

      The recursion below applies if boundary == None, when it defaults to
      (0, 0, S, T); here px_sum and py_sum are the sums of the elements of
      px and py:

          p = tensor of shape (B, S+1, T+1), containing -infinity
          p[b,0,0] = 0.0
          # do the following in a loop over s and t:
          p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
                             p[b,s,t-1] + py_sum[b,s,t-1])
                     (if s > 0 or t > 0)
          return p[:,S,T]

      This function lets you implement the above recursion efficiently,
      except that it gives you a breakdown of the contribution from all the
      elements of px and py separately.  As noted above, the first element
      of the sequence is "special".
    """
    N = len(px)
    assert len(py) == N and N > 0
    B, S, T1 = px[0].shape
    T = T1 - 1
    assert py[0].shape == (B, S + 1, T)
    assert px[0].dtype == py[0].dtype

    px_cat = torch.stack(px, dim=0)  # (N, B, S, T+1)
    py_cat = torch.stack(py, dim=0)  # (N, B, S+1, T)
    px_tot = px_cat.sum(dim=0)  # (B, S, T+1)
    py_tot = py_cat.sum(dim=0)  # (B, S+1, T)

    if boundary is not None:
        assert boundary.dtype == torch.int64
        assert boundary.shape == (B, 4)
        for [s_begin, t_begin, s_end, t_end] in boundary.to('cpu').tolist():
            assert 0 <= s_begin <= s_end <= S
            assert 0 <= t_begin <= t_end <= T
    else:
        boundary = torch.zeros(0, 0, dtype=torch.int64, device=px_tot.device)

    px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
    # The following assertions are for efficiency
    assert px_tot.stride()[-1] == 1 and px_tot.ndim == 3
    assert py_tot.stride()[-1] == 1 and py_tot.ndim == 3

    p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)

    # note, tot_probs is without grad.
    tot_probs = _mutual_information_forward_dispatcher(px_tot, py_tot, boundary, p)

    # this is a kind of "fake gradient" that we use, in effect to compute
    # occupation probabilities.  The backprop will work regardless of the
    # actual derivative w.r.t. the total probs.
    ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)

    (px_grad, py_grad) = _mutual_information_backward_dispatcher(px_tot, py_tot,
                                                                 boundary, p, ans_grad)

    px_grad, py_grad = px_grad.reshape(1, B, -1), py_grad.reshape(1, B, -1)
    px_cat, py_cat = px_cat.reshape(N, B, -1), py_cat.reshape(N, B, -1)
    x_prods = _inner(px_grad, px_cat)  # (N, B)
    y_prods = _inner(py_grad, py_cat)  # (N, B)

    # If all the occupation counts were exactly 1.0 (i.e. no partial counts),
    # "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
    # will be more positive due to the difference between log-of-sum and
    # sum-of-log.
    prods = x_prods + y_prods  # (N, B)
    with torch.no_grad():
        offset = tot_probs - prods.sum(dim=0)  # (B,)
    prods[0] += offset
    return prods  # (N, B)
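
For illustration, here is a minimal usage sketch of the joint API (not part of the diff itself; it assumes the torch_mutual_information extension is built and importable). It splits px and py into two equal halves and checks that the joint result, summed over the term dimension, matches the single-tensor recursion, mirroring the check added to the test below:

import torch
from torch_mutual_information import (mutual_information_recursion,
                                      joint_mutual_information_recursion)

B, S, T = 2, 5, 7
px = torch.randn(B, S, T + 1)  # shape (B, S, T+1)
py = torch.randn(B, S + 1, T)  # shape (B, S+1, T)

m = mutual_information_recursion(px, py)  # shape (B,)

# Two-term split of the same quantities; result has shape (2, B).  Only the
# sum over dim 0 is expected to equal m; the first row absorbs the offset
# between log-of-sum and sum-of-log.
m_joint = joint_mutual_information_recursion((px * 0.5, px * 0.5),
                                             (py * 0.5, py * 0.5))
assert torch.allclose(m, m_joint.sum(dim=0), atol=1e-5)
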
@@ -3,7 +3,7 @@
import random
import torch
-from torch_mutual_information import mutual_information_recursion
+from torch_mutual_information import mutual_information_recursion, joint_mutual_information_recursion
def test_mutual_information_basic():
@@ -73,9 +73,36 @@ def test_mutual_information_basic():
        #m = mutual_information_recursion(px, py, None)
        m = mutual_information_recursion(px, py, boundary)
        m2 = joint_mutual_information_recursion((px,), (py,), boundary)
        m3 = joint_mutual_information_recursion((px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary)
        print("m3, before sum, = ", m3)
        m3 = m3.sum(dim=0)  # it is supposed to be identical only after
                            # summing over dim 0, corresponding to the
                            # sequence dim
        print("m = ", m, ", size = ", m.shape)
        print("m2 = ", m2, ", size = ", m2.shape)
        print("m3 = ", m3, ", size = ", m3.shape)
        assert torch.allclose(m, m2)
        assert torch.allclose(m, m3)
        #print("exp(m) = ", m.exp())
-       (m.sum() * 3).backward()
        # the loop this is in checks that the CPU and CUDA versions give the same
        # derivative; by randomizing which of m, m2 or m3 we backprop, we also
        # ensure that the joint version of the code gives the same derivative
        # as the regular version
        scale = 3
        if random.random() < 0.5:
            (m.sum() * scale).backward()
        elif random.random() < 0.5:
            (m2.sum() * scale).backward()
        else:
            (m3.sum() * scale).backward()
        #print("px_grad = ", px.grad)
        #print("py_grad = ", py.grad)
        px_grads.append(px.grad.to('cpu'))
@@ -92,6 +119,9 @@ def test_mutual_information_basic():
        assert 0


def test_mutual_information_deriv():
    print("Running test_mutual_information_deriv()")