Commit d53e923b authored by pkufool

Move k2 rnnt_loss here

parent b5828e2b
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
PYBIND11_MODULE(_fast_rnnt, m) {
m.doc() = "Python wrapper for Mutual Information.";
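// Note (descriptive, based on the Python caller): `p` is a caller-allocated
// workspace of shape [B][S+1][T+1] that the forward pass fills in; the
// returned tensor has shape [B].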
m.def(
"mutual_information_forward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p) -> torch::Tensor {
if (px.device().is_cpu()) {
return fast_rnnt::MutualInformationCpu(px, py, boundary, p);
} else {
#ifdef FT_WITH_CUDA
return fast_rnnt::MutualInformationCuda(px, py, boundary, p);
#else
// Not compiled with CUDA support (FT_WITH_CUDA is not defined);
// return an empty tensor as a placeholder.
return torch::Tensor();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"));
m.def(
"mutual_information_backward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary, torch::Tensor p,
torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
if (px.device().is_cpu()) {
return fast_rnnt::MutualInformationBackwardCpu(px, py, boundary, p,
ans_grad);
} else {
#ifdef FT_WITH_CUDA
return fast_rnnt::MutualInformationBackwardCuda(px, py, boundary, p,
ans_grad, true);
#else
// Not compiled with CUDA support (FT_WITH_CUDA is not defined);
// return an empty vector as a placeholder.
return std::vector<torch::Tensor>();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad"));
}
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#include "pybind11/pybind11.h"
namespace py = pybind11;
#endif // FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
from .mutual_information import mutual_information_recursion
from .mutual_information import joint_mutual_information_recursion
from .rnnt_loss import do_rnnt_pruning
from .rnnt_loss import get_rnnt_logprobs
from .rnnt_loss import get_rnnt_logprobs_joint
from .rnnt_loss import get_rnnt_logprobs_pruned
from .rnnt_loss import get_rnnt_logprobs_smoothed
from .rnnt_loss import get_rnnt_prune_ranges
from .rnnt_loss import rnnt_loss
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import _fast_rnnt
from torch import Tensor
from typing import Tuple, Optional, Sequence, Union, List
class MutualInformationRecursionFunction(torch.autograd.Function):
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired.
"""
@staticmethod
def forward(
ctx,
px: torch.Tensor,
py: torch.Tensor,
pxy_grads: List[Optional[torch.Tensor]],
boundary: Optional[torch.Tensor] = None,
return_grad: bool = False,
) -> torch.Tensor:
"""
Computing mutual information between two sequences of real vectors.
Args:
px:
A torch.Tensor of some floating point type, with shape
``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the
length of the ``x`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
length of the ``y`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols). In the mutual
information application, ``px[b][s][t]`` would represent the
following log odds ratio; ignoring the b index on the right
to make the notation more
compact::
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in
practice be of the form::
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where ``N`` is the number of terms that the sum over ``t'``
included, which might include some or all of the other sequences as
well as this one.
Note:
We don't require ``px`` and ``py`` to be contiguous, but the
code assumes for optimization purposes that the ``T`` axis has
stride 1.
py:
A torch.Tensor of the same dtype as ``px``, with shape
``[B][S+1][T]``, representing::
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py``
are contiguous.
pxy_grads:
A list into which the grads of ``px`` and ``py`` will be written
if return_grad == True; it is left unchanged if return_grad == False.
See `this PR <https://github.com/k2-fsa/k2/pull/924>` for more
information about why we add this parameter.
Note:
the length of the list must be 2, where the first element
represents the grads of ``px`` and the second one represents
the grads of ``py``.
boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``,
with ``0 <= s_begin <= s_end <= S`` and ``0 <= t_begin <= t_end <= T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the ``x``
and ``y`` sequences respectively, and can be used if not all
sequences are of the same length.
return_grad:
Whether to also return the grads of ``px`` and ``py``. These grads
stand for occupation probabilities and are the output of the backward
pass run with a ``fake gradient`` of ones; they equal the gradients
you would get from ``torch.autograd.grad(scores.sum(), [px, py])``.
This is useful for implementing the pruned version of the RNN-T loss.
Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of
the mutual information between the b'th pair of sequences. This is
defined by the following recursion on ``p[b,s,t]`` (where ``p``
is of shape ``[B,S+1,T+1]``), representing a mutual information
between sub-sequences of lengths ``s`` and ``t``::
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
where we handle edge cases by treating quantities with negative
indexes as **-infinity**. The extension to cases where the
boundaries are specified should be obvious; it just works on
shorter sequences with offsets into ``px`` and ``py``.
"""
(B, S, T1) = px.shape
T = py.shape[-1]
assert T1 in [T, T + 1]
assert py.shape == (B, S + 1, T)
if boundary is not None:
assert boundary.shape == (B, 4)
# p is a tensor of shape (B, S + 1, T + 1), where p[b][s][t] is
# the mutual information of the pair of subsequences of x and y that
# are of length s and t respectively. p[0][0] will be 0.0 and p[S][T]
# is the mutual information of the entire pair of sequences,
# i.e. of lengths S and T respectively.
# It is computed as follows (in C++ and CUDA):
# p[b,0,0] = 0.0
# p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
# p[b,s,t-1] + py[b,s,t-1])
# if s > 0 or t > 0,
# treating values with any -1 index as -infinity.
# .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _fast_rnnt.mutual_information_forward(px, py, boundary, p)
px_grad, py_grad = None, None
if return_grad or px.requires_grad or py.requires_grad:
ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
(px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
px, py, boundary, p, ans_grad)
ctx.save_for_backward(px_grad, py_grad)
assert len(pxy_grads) == 2
pxy_grads[0] = px_grad
pxy_grads[1] = py_grad
return ans
@staticmethod
def backward(
ctx, ans_grad: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
(px_grad, py_grad) = ctx.saved_tensors
(B,) = ans_grad.shape
ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1)
px_grad *= ans_grad
py_grad *= ans_grad
return (px_grad, py_grad, None, None, None)
def mutual_information_recursion(
px: Tensor,
py: Tensor,
boundary: Optional[Tensor] = None,
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired. The definitions of the arguments are definitions that
would be used when computing this type of mutual information, but you can
also view them as arbitrary quantities and just make use of the formula
computed by this function.
Args:
px:
A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``,
where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence
(including representations of ``EOS`` symbols but not ``BOS`` symbols),
and ``T`` is the length of the ``y`` sequence (including representations
of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information
application, ``px[b][s][t]`` would represent the following log odds
ratio; ignoring the b index on the right to make the notation more
compact::
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice
be of the form::
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where ``N`` is the number of terms that the sum over ``t'`` included,
which might include some or all of the other sequences as well as this
one.
Note:
We don't require ``px`` and ``py`` to be contiguous, but the
code assumes for optimization purposes that the ``T`` axis has
stride 1.
py:
A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``,
representing::
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are
contiguous.
boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``,
with ``0 <= s_begin <= s_end <= S`` and ``0 <= t_begin <= t_end <= T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the ``x`` and
``y`` sequences respectively, and can be used if not all sequences are
of the same length.
return_grad:
Whether to also return the grads of ``px`` and ``py``. These grads stand
for occupation probabilities and are the output of the backward pass run
with a ``fake gradient`` of ones; they equal the gradients you would get
from ``torch.autograd.grad(scores.sum(), [px, py])``.
This is useful for implementing the pruned version of the RNN-T loss.
Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual
information between the b'th pair of sequences. This is defined by
the following recursion on ``p[b,s,t]`` (where ``p`` is of shape
``[B,S+1,T+1]``), representing a mutual information between sub-sequences
of lengths ``s`` and ``t``::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
where we handle edge cases by treating quantities with negative indexes
as **-infinity**. The extension to cases where the boundaries are
specified should be obvious; it just works on shorter sequences with
offsets into ``px`` and ``py``.
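Example:
A minimal usage sketch (shapes and values here are illustrative only;
in a real model ``px`` and ``py`` would be log-odds as described above):
>>> import torch
>>> import fast_rnnt
>>> B, S, T = 2, 3, 5
>>> px = torch.randn(B, S, T + 1)  # not "modified": last dim is T + 1
>>> py = torch.randn(B, S + 1, T)
>>> boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)
>>> scores = fast_rnnt.mutual_information_recursion(px, py, boundary)
>>> scores.shape
torch.Size([2])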
"""
assert px.ndim == 3
B, S, T1 = px.shape
T = py.shape[-1]
assert px.shape[-1] in [T, T + 1] # if T, then "modified".
assert py.shape == (B, S + 1, T)
assert px.dtype == py.dtype
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following assertions are for efficiency
assert px.is_contiguous()
assert py.is_contiguous()
pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
px_grad, py_grad = pxy_grads
return (scores, (px_grad, py_grad)) if return_grad else scores
def _inner_product(a: Tensor, b: Tensor) -> Tensor:
"""
Does inner product on the last dimension, with expected broadcasting,
i.e. equivalent to (a * b).sum(dim=-1)
without creating a large temporary.
"""
assert a.shape[-1] == b.shape[-1] # The last dim must be equal
a = a.unsqueeze(-2) # (..., 1, K)
b = b.unsqueeze(-1) # (..., K, 1)
c = torch.matmul(a, b) # (..., 1, 1)
return c.squeeze(-1).squeeze(-1)
def joint_mutual_information_recursion(
px: Sequence[Tensor],
py: Sequence[Tensor],
boundary: Optional[Tensor] = None,
) -> Sequence[Tensor]:
"""A recursion that is useful for modifications of RNN-T and similar loss
functions, where the recursion probabilities have a number of terms and you
want them reported separately. See mutual_information_recursion() for more
documentation of the basic aspects of this.
Args:
px:
a sequence of Tensors, each of the same shape [B][S][T+1]
py:
a sequence of Tensor, each of the same shape [B][S+1][T],
the sequence must be the same length as px.
boundary:
optionally, a LongTensor of shape [B][4] containing rows
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S
and 0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
These are the beginning and one-past-the-last positions in the x
and y sequences respectively, and can be used if not all
sequences are of the same length.
Returns:
a Tensor of shape (len(px), B),
whose sum over dim 0 is the total log-prob of the recursion mentioned
below, per sequence. The first element of the sequence of length len(px)
is "special", in that it has an offset term reflecting the difference
between sum-of-log and log-of-sum; for more interpretable loss values,
the "main" part of your loss function should be first.
The recursion below applies if boundary == None, when it defaults
to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px
and py::
p = tensor of shape (B, S+1, T+1), containing -infinity
p[b,0,0] = 0.0
# do the following in loop over s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
p[b,s,t-1] + py_sum[b,s,t-1])
(if s > 0 or t > 0)
return p[:][S][T]
This function lets you implement the above recursion efficiently, except
that it gives you a breakdown of the contribution from all the elements of
px and py separately. As noted above, the first element of the
sequence is "special".
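Example:
A minimal sketch with two terms (illustrative shapes only); summing the
result over dim 0 should match mutual_information_recursion() applied to
the summed inputs, up to floating point error:
>>> import torch
>>> import fast_rnnt
>>> B, S, T = 2, 3, 5
>>> px1, px2 = torch.randn(B, S, T + 1), torch.randn(B, S, T + 1)
>>> py1, py2 = torch.randn(B, S + 1, T), torch.randn(B, S + 1, T)
>>> prods = fast_rnnt.joint_mutual_information_recursion(
...     (px1, px2), (py1, py2))
>>> prods.shape  # (len(px), B)
torch.Size([2, 2])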
"""
N = len(px)
assert len(py) == N and N > 0
B, S, T1 = px[0].shape
T = py[0].shape[2]
assert T1 in [T, T + 1] # T if modified...
assert py[0].shape == (B, S + 1, T)
assert px[0].dtype == py[0].dtype
px_cat = torch.stack(
px, dim=0
) # (N, B, S, T+1) if !modified, (N, B, S, T) if modified.
py_cat = torch.stack(py, dim=0) # (N, B, S+1, T)
px_tot = px_cat.sum(dim=0) # (B, S, T+1)
py_tot = py_cat.sum(dim=0) # (B, S+1, T)
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
# The following assertions are for efficiency
assert px_tot.ndim == 3
assert py_tot.ndim == 3
p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)
# note, tot_probs is without grad.
tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p)
# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
(px_grad,
py_grad) = _fast_rnnt.mutual_information_backward(px_tot, py_tot, boundary, p,
ans_grad)
px_grad = px_grad.reshape(1, B, -1)
py_grad = py_grad.reshape(1, B, -1)
px_cat = px_cat.reshape(N, B, -1)
py_cat = py_cat.reshape(N, B, -1)
# Get rid of -inf, which would generate nan on product with 0.
px_cat = px_cat.clamp(min=torch.finfo(px_cat.dtype).min)
py_cat = py_cat.clamp(min=torch.finfo(py_cat.dtype).min)
x_prods = _inner_product(px_grad, px_cat) # (N, B)
y_prods = _inner_product(py_grad, py_cat) # (N, B)
# If all the occupation counts were exactly 1.0 (i.e. no partial counts),
# "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
# will be more positive due to the difference between log-of-sum and
# sum-of-log
prods = x_prods + y_prods # (N, B)
with torch.no_grad():
offset = tot_probs - prods.sum(dim=0) # (B,)
prods[0] += offset
return prods # (N, B)
# Copyright 2021 Xiaomi Corp. (author: Daniel Povey, Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import k2
import torch
from torch import Tensor
from typing import Optional, Tuple, Union
from .mutual_information import mutual_information_recursion
def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
"""
Insert -inf's into `px` in appropriate places if `boundary` is not
None. If boundary == None and modified == False, px[:,:,-1] will
be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
to be -infinity.
Args:
px: a Tensor of shape [B][S][T+1] (this function is only
called if modified == False, see other docs for `modified`)
px is modified in-place and returned.
boundary: None, or a Tensor of shape [B][4] containing
[s_begin, t_begin, s_end, t_end]; we need only t_end.
"""
if boundary is None:
return px
B, S, T1 = px.shape
boundary = boundary[:, 3].reshape(B, 1, 1).expand(B, S, T1)
return px.scatter_(dim=2, index=boundary, value=float("-inf"))
def get_rnnt_logprobs(
lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""
Reduces RNN-T problem (the simple case, where joiner network is just
addition), to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This function is called from rnnt_loss_simple(), but may be useful for
other purposes.
Args:
lm:
Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape::
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am:
Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape::
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
of length s and t respectively. px[b][s][t] represents the
probability of extending the subsequences of length (s,t) by one in
the s direction, given the particular symbol, and py[b][s][t]
represents the probability of extending the subsequences of length
(s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert lm.ndim == 3
assert am.ndim == 3
assert lm.shape[0] == am.shape[0]
assert lm.shape[2] == am.shape[2]
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
# subtracting am_max and lm_max is to ensure the probs are in a good range
# to do exp() without causing underflow or overflow.
am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1]
lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1]
am_probs = (am - am_max).exp()
lm_probs = (lm - lm_max).exp()
# normalizers: [B][S+1][T]
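# i.e. normalizers[b][s][t] = log(sum_c exp(lm[b][s][c] + am[b][t][c])),
# computed as a matmul in (max-shifted) probability space; the max terms
# are added back a few lines below.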
normalizers = (
torch.matmul(lm_probs, am_probs.transpose(1, 2))
+ torch.finfo(am_probs.dtype).tiny
).log()
# add lm_max and am_max to normalizers, to make it as if we had not
# subtracted am_max and lm_max above.
normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T]
# px is the probs of the actual symbols..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
if not modified:
px_am = torch.cat(
(
px_am,
torch.full(
(B, S, 1),
float("-inf"),
device=px_am.device,
dtype=px_am.dtype,
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
px_lm = torch.gather(
lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]
px = px_am + px_lm # [B][S][T+1] if !modified, [B][S][T] if modified;
# for !modified, the last slice [:,:,T] is -inf
px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1]
# py is the probs of termination symbols, of shape [B][S+1][T]
py_am = am[:, :, termination_symbol].unsqueeze(1) # [B][1][T]
py_lm = lm[:, :, termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
if not modified:
px = fix_for_boundary(px, boundary)
return (px, py)
def rnnt_loss_simple(
lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
"""A simple case of the RNN-T loss, where the 'joiner' network is just
addition.
Args:
lm:
language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes
am:
acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols:
the symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
return_grad:
Whether to also return the grads of px and py. These grads stand for
occupation probabilities and are the output of the backward pass run with
a `fake gradient` of ones; they equal the gradients you would get from
`torch.autograd.grad((-loss.sum()), [px, py])`, where the loss is computed
with reduction "none".
This is useful for implementing the pruned version of the RNN-T loss.
Returns:
If return_grad is False, returns a tensor of shape (B,) containing the
total RNN-T loss values for each element of the batch if reduction is
"none", otherwise a scalar with the reduction applied.
If return_grad is True, the grads of px and py, which are the output of
the backward pass with a `fake gradient` (see above), will be returned
too, and the returned value will be a tuple (loss, (px_grad, py_grad)).
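Example:
A minimal sketch (shapes are illustrative; `lm` and `am` would normally
come from a prediction network and an encoder):
>>> import torch
>>> import fast_rnnt
>>> B, S, T, C = 2, 3, 5, 10
>>> lm = torch.randn(B, S + 1, C)
>>> am = torch.randn(B, T, C)
>>> symbols = torch.randint(1, C, (B, S))
>>> boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)
>>> loss = fast_rnnt.rnnt_loss_simple(lm, am, symbols,
...                                   termination_symbol=0,
...                                   boundary=boundary)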
"""
px, py = get_rnnt_logprobs(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
if reduction == "none":
loss = -negated_loss
elif reduction == "mean":
loss = -torch.mean(negated_loss)
elif reduction == "sum":
loss = -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
return (loss, scores_and_grads[1]) if return_grad else loss
def get_rnnt_logprobs_joint(
logits: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Reduces RNN-T problem to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This function is called from rnnt_loss().
Args:
logits:
The output of joiner network, with shape (B, T, S + 1, C),
i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary)::
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert logits.ndim == 4
(B, T, S1, C) = logits.shape
S = S1 - 1
assert symbols.shape == (B, S)
normalizers = torch.logsumexp(logits, dim=3)
normalizers = normalizers.permute((0, 2, 1))
px = torch.gather(
logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
).squeeze(-1)
px = px.permute((0, 2, 1))
if not modified:
px = torch.cat(
(
px,
torch.full(
(B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
px[:, :, :T] -= normalizers[:, :S, :]
py = (
logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
) # [B][S+1][T]
py -= normalizers
px = px.contiguous()
py = py.contiguous()
if not modified:
px = fix_for_boundary(px, boundary)
return (px, py)
def rnnt_loss(
logits: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A normal RNN-T loss, which uses a 'joiner' network output as input,
i.e. a 4-dimensional tensor.
Args:
logits:
The output of joiner network, with shape (B, T, S + 1, C),
i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
symbols:
The symbol sequences, a LongTensor of shape [B][S], and elements
in {0..C-1}.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
Returns:
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch, otherwise a scalar
with the reduction applied.
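Example:
A minimal sketch (illustrative only; `logits` would normally be the
joiner output over all (t, s) pairs):
>>> import torch
>>> import fast_rnnt
>>> B, S, T, C = 2, 3, 5, 10
>>> logits = torch.randn(B, T, S + 1, C)
>>> symbols = torch.randint(1, C, (B, S))
>>> loss = fast_rnnt.rnnt_loss(logits, symbols, termination_symbol=0)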
"""
px, py = get_rnnt_logprobs_joint(
logits=logits,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
if reduction == "none":
return -negated_loss
elif reduction == "mean":
return -torch.mean(negated_loss)
elif reduction == "sum":
return -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
def _adjust_pruning_lower_bound(
s_begin: torch.Tensor, s_range: int
) -> torch.Tensor:
"""Adjust s_begin (pruning lower bound) so that it satisfies the following
constraints:
- monotonically increasing, i.e. s_begin[i] <= s_begin[i + 1]
- starts with symbol 0 at the first frame.
- s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
any symbols.
To make it monotonically increasing, we can use the `monotonic_lower_bound`
function in k2, which guarantees `s_begin[i] <= s_begin[i + 1]`. The main
idea is: traverse the array in reverse order and update each element with
`min_value = min(s_begin[i], min_value)`, with the initial `min_value` set
to `inf`.
The method we use to enforce the `s_begin[i + 1] - s_begin[i] < s_range`
constraint is a little tricky. We first transform `s_begin` with
`s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T))`,
then make the transformed `s_begin` monotonically increasing, and after
that transform it back with the same formula. The idea is: if we want
`s_begin[i + 1] - s_begin[i] < s_range`, we only need to make
`-(s_begin[i] - i * (s_range - 1))` a non-decreasing array. Proof:
-(s_begin[i] - i * (s_range - 1)) <= -(s_begin[i + 1] - (i + 1) * (s_range - 1))
-s_begin[i] <= -s_begin[i + 1] + (i + 1) * (s_range - 1) - i * (s_range - 1)
-s_begin[i] <= -s_begin[i + 1] + s_range - 1
s_begin[i + 1] - s_begin[i] <= s_range - 1
s_begin[i + 1] - s_begin[i] < s_range
The above transformation cannot guarantee that the start symbol is 0, so
we set all elements that are less than 0 to 0 before transforming
`s_begin` back.
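For example (an illustrative trace of the steps above): with s_range = 2,
T = 4 and s_begin = [0, 2, 2, 3] (which skips a symbol between frames 0
and 1), the transform gives [0, -1, 0, 0]; making it non-decreasing gives
[-1, -1, 0, 0]; setting negatives to 0 gives [0, 0, 0, 0]; and transforming
back gives [0, 1, 2, 3], which starts at 0 and never increases by more
than 1 per frame.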
"""
# s_begin (B, T)
(B, T) = s_begin.shape
s_begin = k2.monotonic_lower_bound(s_begin)
# do the magic transformation
s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
)
# make the transformed tensor to be non-decreasing
s_begin = k2.monotonic_lower_bound(s_begin)
# make start symbol to be zero.
s_begin = torch.where(s_begin < 0, 0, s_begin)
# do the magic transformation again to recover s_begin
s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
)
return s_begin
def get_rnnt_prune_ranges(
px_grad: torch.Tensor,
py_grad: torch.Tensor,
boundary: torch.Tensor,
s_range: int,
) -> torch.Tensor:
"""Get the pruning ranges of the normal RNN-T loss according to the grads
of px and py returned by mutual_information_recursion.
For each sequence with T frames, we generate a tensor of shape
(T, s_range) telling which symbols will be taken into consideration
for each frame. For example, for a sequence with 10 frames whose
corresponding symbols are `[A B C D E F]`, if s_range equals 3, one
possible ranges tensor would be::
[[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [1, 2, 3],
[1, 2, 3], [1, 2, 3], [3, 4, 5], [3, 4, 5], [3, 4, 5]]
which means we only consider `[A B C]` at frames 0, 1, 2, 3, `[B C D]`
at frames 4, 5, 6, and `[D E F]` at frames 7, 8, 9.
We can restrict each frame to a limited number of symbols because frames
and symbols are monotonically aligned, so only a particular range of
symbols can occur at a particular frame.
Note:
For the generated tensor ranges, ranges[:, 0] is a monotonically
increasing tensor from 0 to `len(symbols)` and it satisfies
`ranges[t+1, 0] - ranges[t, 0] < s_range`, which means we won't skip any
symbols.
Args:
px_grad:
The gradient of px, see docs in `mutual_information_recursion` for more
details of px.
py_grad:
The gradient of py, see docs in `mutual_information_recursion` for more
details of py.
boundary:
a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]
s_range:
How many symbols to keep for each frame.
Returns:
A tensor containing the indexes of the kept symbols for each frame, with
shape (B, T, s_range).
"""
(B, S, T1) = px_grad.shape
T = py_grad.shape[-1]
assert T1 in [T, T + 1]
assert py_grad.shape == (B, S + 1, T)
assert boundary.shape == (B, 4)
assert s_range >= 1
if s_range > S:
s_range = S
px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device)
py_pad = torch.zeros(
(B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device
)
py_grad_padded = py_grad if T1 == T else torch.cat((py_grad, py_pad), dim=2)
tot_grad = (
torch.cat((px_grad, px_pad), dim=1) + py_grad_padded
) # (B, S + 1, T1)
tot_grad = torch.cat(
(
torch.zeros(
(B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device
),
tot_grad,
),
dim=1,
)
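# After the cumsum below, tot_grad[b, s, t] is (roughly) the total
# occupation probability of the first s symbol positions at frame t;
# diff_grad[b, s, t] is then the mass of the window of s_range symbols
# starting at s, and the argmax picks, for each frame, the window that
# covers the most mass.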
tot_grad = torch.cumsum(tot_grad, dim=1)
diff_grad = tot_grad[:, s_range:, :] - tot_grad[:, 0:-s_range, :]
s_begin = torch.argmax(diff_grad, dim=1)
s_begin = s_begin[:, :T]
# Handle the values of s_begin at padding positions.
# The `- 1` below means that the last frame of real data is also filled
# with the padding value, which is `len(symbols) - s_range + 1`; this
# guarantees that we reach the last symbol on the last frame of real data.
mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
mask = mask < boundary[:, 3].reshape(B, 1) - 1
s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
# handle the cases when `len(symbols) < s_range`
s_begin_padding = torch.where(s_begin_padding >= 0, s_begin_padding, 0)
s_begin = torch.where(mask, s_begin, s_begin_padding)
# Adjust the lower bound so that it satisfies some constraints; see docs
# in `_adjust_pruning_lower_bound` for more details of these constraints.
# T1 == T here means we are using the modified version of the transducer,
# so the third constraint becomes `s_begin[i + 1] - s_begin[i] < 2`,
# because at most one symbol can be emitted per frame.
s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)
ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
s_range, device=px_grad.device
)
return ranges
def do_rnnt_pruning(
am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Prune the output of the encoder (am) and the output of the prediction
network (lm) of RNN-T.
Args:
am:
The encoder output, with shape (B, T, C)
lm:
The prediction network output, with shape (B, S + 1, C)
ranges:
A tensor containing the symbol indexes for each frame that we want to
keep. Its shape is (B, T, s_range), see the docs in
`get_rnnt_prune_ranges` for more details of this tensor.
Returns:
Return the pruned am and lm with shape (B, T, s_range, C)
"""
# am (B, T, C)
# lm (B, S + 1, C)
# ranges (B, T, s_range)
assert ranges.shape[0] == am.shape[0]
assert ranges.shape[0] == lm.shape[0]
assert am.shape[1] == ranges.shape[1]
(B, T, s_range) = ranges.shape
(B, S1, C) = lm.shape
S = S1 - 1
# (B, T, s_range, C)
am_pruning = am.unsqueeze(2).expand((B, T, s_range, C))
# (B, T, s_range, C)
lm_pruning = torch.gather(
lm.unsqueeze(1).expand((B, T, S + 1, C)),
dim=2,
index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C)),
)
return am_pruning, lm_pruning
def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
"""Roll tensor with different shifts for each row.
Note:
We assume the src is a 3-dimensional tensor and roll the last dimension.
Example:
>>> src = torch.arange(15).reshape((1,3,5))
>>> src
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]]])
>>> shift = torch.tensor([[1, 2, 3]])
>>> shift
tensor([[1, 2, 3]])
>>> _roll_by_shifts(src, shift)
tensor([[[ 4, 0, 1, 2, 3],
[ 8, 9, 5, 6, 7],
[12, 13, 14, 10, 11]]])
"""
assert src.dim() == 3
(B, T, S) = src.shape
assert shifts.shape == (B, T)
index = (
torch.arange(S, device=src.device)
.view((1, S))
.repeat((T, 1))
.repeat((B, 1, 1))
)
index = (index - shifts.reshape(B, T, 1)) % S
return torch.gather(src, 2, index)
def get_rnnt_logprobs_pruned(
logits: Tensor,
symbols: Tensor,
ranges: Tensor,
termination_symbol: int,
boundary: Tensor,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Construct px, py for mutual_information_recursion with pruned output.
Args:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C)
symbols:
The symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
Return the px (B, S, T) if modified else (B, S, T + 1) and
py (B, S + 1, T) needed by mutual_information_recursion.
"""
# logits (B, T, s_range, C)
# symbols (B, S)
# ranges (B, T, s_range)
assert logits.ndim == 4
(B, T, s_range, C) = logits.shape
assert ranges.shape == (B, T, s_range)
(B, S) = symbols.shape
normalizers = torch.logsumexp(logits, dim=3)
symbols_with_terminal = torch.cat(
(
symbols,
torch.tensor(
[termination_symbol] * B,
dtype=torch.int64,
device=symbols.device,
).reshape((B, 1)),
),
dim=1,
)
# (B, T, s_range)
pruned_symbols = torch.gather(
symbols_with_terminal.unsqueeze(1).expand((B, T, S + 1)),
dim=2,
index=ranges,
)
# (B, T, s_range)
px = torch.gather(
logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1)
).squeeze(-1)
px = px - normalizers
# (B, T, S + 1) with indexes >= s_range in dim 2 filled with -inf
px = torch.cat(
(
px,
torch.full(
(B, T, S + 1 - s_range),
float("-inf"),
device=px.device,
dtype=px.dtype,
),
),
dim=2,
)
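# The s_range valid entries for each frame currently occupy indexes
# 0..s_range-1; _roll_by_shifts moves them to start at ranges[b, t, 0],
# i.e. to their true positions on the full symbol axis.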
# (B, T, S) with indexes outside each frame's s_range window filled with -inf
px = _roll_by_shifts(px, ranges[:, :, 0])[:, :, :S]
px = px.permute((0, 2, 1))
if not modified:
px = torch.cat(
(
px,
torch.full(
(B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
py = logits[:, :, :, termination_symbol].clone() # (B, T, s_range)
py = py - normalizers
# (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf
py = torch.cat(
(
py,
torch.full(
(B, T, S + 1 - s_range),
float("-inf"),
device=py.device,
dtype=py.dtype,
),
),
dim=2,
)
# (B, T, S + 1) with indexes outside each frame's s_range window filled with -inf
py = _roll_by_shifts(py, ranges[:, :, 0])
# (B, S + 1, T)
py = py.permute((0, 2, 1))
px = px.contiguous()
py = py.contiguous()
if not modified:
px = fix_for_boundary(px, boundary)
return (px, py)
def rnnt_loss_pruned(
logits: Tensor,
symbols: Tensor,
ranges: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
reduction: Optional[str] = "mean",
) -> Tensor:
"""An RNN-T loss with pruning, which uses a pruned 'joiner' network output
as input, i.e. a 4-dimensional tensor with shape (B, T, s_range, C), where
s_range is the number of symbols kept for each frame.
Args:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C),
i.e. batch, time_seq_len, prune_range, num_classes
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
Returns:
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch, otherwise a scalar
with the reduction applied.
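Example:
A typical pruned-training sketch (shapes are illustrative; `am_p + lm_p`
below stands in for a real joiner network applied to the pruned pair):
>>> import torch
>>> import fast_rnnt
>>> B, S, T, C = 2, 3, 5, 10
>>> lm = torch.randn(B, S + 1, C)
>>> am = torch.randn(B, T, C)
>>> symbols = torch.randint(1, C, (B, S))
>>> boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)
>>> simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
...     lm, am, symbols, termination_symbol=0, boundary=boundary,
...     return_grad=True)
>>> ranges = fast_rnnt.get_rnnt_prune_ranges(px_grad, py_grad, boundary,
...                                          s_range=3)
>>> am_p, lm_p = fast_rnnt.do_rnnt_pruning(am, lm, ranges)
>>> pruned_loss = fast_rnnt.rnnt_loss_pruned(am_p + lm_p, symbols, ranges,
...                                          termination_symbol=0,
...                                          boundary=boundary)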
"""
px, py = get_rnnt_logprobs_pruned(
logits=logits,
symbols=symbols,
ranges=ranges,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
if reduction == "none":
return -negated_loss
elif reduction == "mean":
return -torch.mean(negated_loss)
elif reduction == "sum":
return -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
def get_rnnt_logprobs_smoothed(
lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""
Reduces RNN-T problem (the simple case, where joiner network is just
addition), to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This version allows you to make the loss-function one of the form::
lm_only_scale * lm_probs +
am_only_scale * am_probs +
(1-lm_only_scale-am_only_scale) * combined_probs
where lm_probs and am_probs are the probabilities given the lm and acoustic
model independently.
This function is called from
:func:`rnnt_loss_smoothed`, but may be useful for other purposes.
Args:
lm:
Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape::
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am:
Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape::
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
lm_only_scale:
the scale on the "LM-only" part of the loss.
am_only_scale:
the scale on the "AM-only" part of the loss, for which we use
an "averaged" LM (averaged over all histories, so effectively unigram).
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
of length s and t respectively. px[b][s][t] represents the
probability of extending the subsequences of length (s,t) by one in
the s direction, given the particular symbol, and py[b][s][t]
represents the probability of extending the subsequences of length
(s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols. This is simply
a way of incorporating the probability of the termination symbol on
the last frame.
"""
assert lm.ndim == 3
assert am.ndim == 3
assert lm.shape[0] == am.shape[0]
assert lm.shape[2] == am.shape[2]
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
# Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that
# all of the logprobs here are properly normalized. We test that
# this code is invariant to adding constants in the appropriate ways.
# subtracting am_max and lm_max is to ensure the probs are in a good range
# to do exp() without causing underflow or overflow.
am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1]
lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1]
am_probs = (am - am_max).exp() # [B][T][C]
lm_probs = (lm - lm_max).exp() # [B][S+1][C]
# normalizers: [B][S+1][T]
normalizers = (
torch.matmul(lm_probs, am_probs.transpose(1, 2))
+ torch.finfo(lm_probs.dtype).tiny
).log()
# normalizer per frame, if we take only the LM probs by themselves
lmonly_normalizers = lm_probs.sum(
dim=2, keepdim=True
) # lmonly_normalizers: [B][S+1][1]
unigram_lm = (
torch.mean(lm_probs / lmonly_normalizers, dim=(0, 1), keepdim=True)
+ torch.finfo(lm_probs.dtype).tiny
) # [1][1][C]
amonly_normalizers = (
torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C))
.reshape(B, T, 1)
.log()
+ am_max
) # [B][T][1]
amonly_normalizers = amonly_normalizers.transpose(1, 2) # [B][1][T]
unigram_lm = unigram_lm.log()
lmonly_normalizers = (
lmonly_normalizers.log() + lm_max
) # [B][S+1][1], log-normalizer, used for LM-only part of prob.
# add lm_max and am_max to normalizers, to make it as if we had not
# subtracted am_max and lm_max above.
normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T]
# px is the probs of the actual symbols (not yet normalized)..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
if not modified:
px_am = torch.cat(
(
px_am,
torch.full(
(B, S, 1),
float("-inf"),
device=px_am.device,
dtype=px_am.dtype,
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
px_lm = torch.gather(
lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]
px_lm_unigram = torch.gather(
unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]
px = px_am + px_lm # [B][S][T+1] if not modified, [B][S][T] if modified
px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1] or [B][S][T]
px_amonly = (
px_am + px_lm_unigram
) # [B][S][T+1] if !modified; [B][S][T] if modified.
px_amonly[:, :, :T] -= amonly_normalizers
px_lmonly = px_lm - lmonly_normalizers[:, :S, :]
# py is the probs of termination symbols, of shape [B][S+1][T]
py_am = am[:, :, termination_symbol].unsqueeze(1) # [B][1][T]
py_lm = lm[:, :, termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
py_lm_unigram = unigram_lm[0][0][termination_symbol] # scalar, normalized..
py_amonly = py_am + py_lm_unigram - amonly_normalizers # [B][S+1][T]
py_lmonly = py_lm - lmonly_normalizers # [B][S+1][T]
combined_scale = 1.0 - lm_only_scale - am_only_scale
# We need to avoid exact zeros in the scales because otherwise multiplying
# -inf by zero generates nan.
if lm_only_scale == 0.0:
lm_only_scale = 1.0e-20
if am_only_scale == 0.0:
am_only_scale = 1.0e-20
px_interp = (
px * combined_scale
+ px_lmonly * lm_only_scale
+ px_amonly * am_only_scale
)
py_interp = (
py * combined_scale
+ py_lmonly * lm_only_scale
+ py_amonly * am_only_scale
)
if not modified:
px_interp = fix_for_boundary(px_interp, boundary)
return (px_interp, py_interp)
def rnnt_loss_smoothed(
lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
modified: bool = False,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""A simple case of the RNN-T loss, where the 'joiner' network is just
addition, smoothed by interpolating with "LM-only" and "AM-only" terms
(see :func:`get_rnnt_logprobs_smoothed`).
Args:
lm:
language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
These are assumed to be well-normalized, in the sense that we could
use them as probabilities separately from the am scores
am:
acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols:
the symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
lm_only_scale:
the scale on the "LM-only" part of the loss.
am_only_scale:
the scale on the "AM-only" part of the loss, for which we use
an "averaged" LM (averaged over all histories, so effectively unigram).
boundary:
a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
return_grad:
Whether to also return the grads of px and py. These grads stand for
occupation probabilities and are the output of the backward pass run with
a `fake gradient` of ones; they equal the gradients you would get from
`torch.autograd.grad((-loss.sum()), [px, py])`, where the loss is computed
with reduction "none".
This is useful for implementing the pruned version of the RNN-T loss.
Returns:
If return_grad is False, returns a tensor of shape (B,) containing the
total RNN-T loss values for each element of the batch if reduction is
"none", otherwise a scalar with the reduction applied.
If return_grad is True, the grads of px and py, which are the output of
the backward pass with a `fake gradient` (see above), will be returned
too, and the returned value will be a tuple (loss, (px_grad, py_grad)).
"""
px, py = get_rnnt_logprobs_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=lm_only_scale,
am_only_scale=am_only_scale,
boundary=boundary,
modified=modified,
)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
if reduction == "none":
loss = -negated_loss
elif reduction == "mean":
loss = -torch.mean(negated_loss)
elif reduction == "sum":
loss = -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
return (loss, scores_and_grads[1]) if return_grad else loss
function(fast_rnnt_add_py_test source)
get_filename_component(name ${source} NAME_WE)
set(name "${name}_py")
add_test(NAME ${name}
COMMAND
"${PYTHON_EXECUTABLE}"
"${CMAKE_CURRENT_SOURCE_DIR}/${source}"
)
get_filename_component(fast_rnnt_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
set_property(TEST ${name}
PROPERTY ENVIRONMENT "PYTHONPATH=${fast_rnnt_path}:$<TARGET_FILE_DIR:_fast_rnnt>:$ENV{PYTHONPATH}"
)
endfunction()
# please sort the files in alphabetic order
set(py_test_files
mutual_information_test.py
rnnt_loss_test.py
)
foreach(source IN LISTS py_test_files)
fast_rnnt_add_py_test(${source})
endforeach()
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (authors: Daniel Povey,
# Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this single test, use
#
# ctest --verbose -R mutual_information_test_py
import random
import unittest
import fast_rnnt
import torch
# Caution: this will fail occasionally due to cutoffs not being quite large
# enough. As long as it passes most of the time, it's OK.
class TestMutualInformation(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device("cpu")]
if torch.cuda.is_available():
cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
cls.devices.append(torch.device("cuda", 1))
cls.dtypes = [torch.float32, torch.float64]
def test_mutual_information_basic(self):
for _iter in range(100):
(B, S, T) = (
random.randint(1, 10),
random.randint(1, 16),
random.randint(1, 500),
)
random_px = random.random() < 0.2
random_py = random.random() < 0.2
random_boundary = random.random() < 0.7
big_px = random.random() < 0.2
big_py = random.random() < 0.2
modified = random.random() < 0.5
if modified and T < S:
T = S + random.randint(0, 30)
for dtype in self.dtypes:
for device in self.devices:
if random_boundary:
def get_boundary_row():
this_S = random.randint(
0, S
) # allow empty sequence
this_T = random.randint(
this_S if modified else 1, T
)
s_begin = random.randint(0, S - this_S)
t_begin = random.randint(0, T - this_T)
s_end = s_begin + this_S
t_end = t_begin + this_T
return [s_begin, t_begin, s_end, t_end]
if device == torch.device("cpu"):
boundary = torch.tensor(
[get_boundary_row() for _ in range(B)],
dtype=torch.int64,
device=device,
)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly
# or not.
if random.random() < 0.5:
boundary = (
torch.tensor([0, 0, S, T], dtype=torch.int64)
.unsqueeze(0)
.expand(B, 4)
.to(device)
)
else:
boundary = None
if device == torch.device("cpu"):
if random_px:
# log of an odds ratio
px = torch.randn(
B, S, T + (0 if modified else 1), dtype=dtype
).to(device)
if S > 1 and not random_boundary and not modified:
px[:, :, -1:] = float("-inf")
else:
# log of an odds ratio
px = torch.zeros(
B, S, T + (0 if modified else 1), dtype=dtype
).to(device)
# px and py get exponentiated, and then multiplied
# together up to 32 times (BLOCK_SIZE in the CUDA code),
# so 15 is actually a big number that could lead to
# overflow.
if big_px:
px += 15.0
if random_py:
# log of an odds ratio
py = torch.randn(B, S + 1, T, dtype=dtype).to(
device
)
else:
# log of an odds ratio
py = torch.zeros(B, S + 1, T, dtype=dtype).to(
device
)
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
m = fast_rnnt.mutual_information_recursion(px, py, boundary)
m2 = fast_rnnt.joint_mutual_information_recursion(
(px,), (py,), boundary
)
m3 = fast_rnnt.joint_mutual_information_recursion(
(px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary
)
# it is supposed to be identical only after
# summing over dim 0, corresponding to the
# sequence dim
m3 = m3.sum(dim=0)
assert torch.allclose(m, m2)
assert torch.allclose(m, m3)
# the loop this is in checks that the CPU and CUDA versions
# give the same derivative;
# by randomizing which of m, m2 or m3 we backprop, we also
# ensure that the joint version of the code gives the same
# derivative as the regular version
scale = 3
if random.random() < 0.5:
(m.sum() * scale).backward()
elif random.random() < 0.5:
(m2.sum() * scale).backward()
else:
(m3.sum() * scale).backward()
if device == torch.device("cpu"):
expected_px_grad = px.grad
expected_py_grad = py.grad
expected_m = m
assert torch.allclose(
px.grad,
expected_px_grad.to(device),
atol=1.0e-02,
rtol=1.0e-02,
)
assert torch.allclose(
py.grad,
expected_py_grad.to(device),
atol=1.0e-02,
rtol=1.0e-02,
)
assert torch.allclose(
m, expected_m.to(device), atol=1.0e-02, rtol=1.0e-02
)
def test_mutual_information_deriv(self):
for _iter in range(100):
(B, S, T) = (
random.randint(1, 100),
random.randint(1, 200),
random.randint(1, 200),
)
random_px = random.random() < 0.2
random_py = random.random() < 0.2
random_boundary = random.random() < 0.7
big_px = random.random() < 0.2
big_py = random.random() < 0.2
modified = random.random() < 0.5
if modified and T < S:
T = S + random.randint(0, 30)
for dtype in self.dtypes:
for device in self.devices:
if random_boundary:
def get_boundary_row():
this_S = random.randint(1, S)
this_T = random.randint(
this_S if modified else 1, T
)
s_begin = random.randint(0, S - this_S)
t_begin = random.randint(0, T - this_T)
s_end = s_begin + this_S
t_end = t_begin + this_T
return [s_begin, t_begin, s_end, t_end]
if device == torch.device("cpu"):
boundary = torch.tensor(
[get_boundary_row() for _ in range(B)],
dtype=torch.int64,
device=device,
)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly
# or not.
if random.random() < 0.5:
boundary = (
torch.tensor([0, 0, S, T], dtype=torch.int64)
.unsqueeze(0)
.expand(B, 4)
.to(device)
)
else:
boundary = None
T1 = T + (0 if modified else 1)
if device == torch.device("cpu"):
if random_px:
# log of an odds ratio
px = torch.randn(B, S, T1, dtype=dtype).to(device)
else:
# log of an odds ratio
px = torch.zeros(B, S, T1, dtype=dtype).to(device)
# px and py get exponentiated, and then multiplied
# together up to 32 times (BLOCK_SIZE in the CUDA code),
# so 15 is actually a big number that could lead to
# overflow.
if big_px:
px += 15.0
if random_py:
# log of an odds ratio
py = torch.randn(B, S + 1, T, dtype=dtype).to(
device
)
else:
# log of an odds ratio
py = torch.zeros(B, S + 1, T, dtype=dtype).to(
device
)
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
m = fast_rnnt.mutual_information_recursion(px, py, boundary)
m_grad = torch.randn(B, dtype=dtype, device=device)
m.backward(gradient=m_grad)
delta = 1.0e-04
delta_px = delta * torch.randn_like(px)
m2 = fast_rnnt.mutual_information_recursion(
px + delta_px, py, boundary
)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to("cpu")
predicted_delta = (delta_px * px.grad).sum().to("cpu")
atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
assert torch.allclose(
observed_delta, predicted_delta, atol=atol, rtol=rtol
)
delta_py = delta * torch.randn_like(py)
m2 = fast_rnnt.mutual_information_recursion(
px, py + delta_py, boundary
)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to("cpu")
predicted_delta = (delta_py * py.grad).sum().to("cpu")
assert torch.allclose(
observed_delta, predicted_delta, atol=atol, rtol=rtol
)
if __name__ == "__main__":
unittest.main()
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (authors: Daniel Povey,
# Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this single test, use
#
# ctest --verbose -R rnnt_loss_test_py
import unittest
import fast_rnnt
import random
import torch
class TestRnntLoss(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device("cpu")]
if torch.cuda.is_available():
cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
cls.devices.append(torch.device("cuda", 1))
try:
import torchaudio
import torchaudio.functional
if hasattr(torchaudio.functional, "rnnt_loss"):
cls.has_torch_rnnt_loss = True
else:
cls.has_torch_rnnt_loss = False
print(
f"Current torchaudio version: {torchaudio.__version__}\n"
"Skipping the tests comparing rnnt loss with torchaudio's; "
"to enable them, please install torchaudio >= 0.10.0"
)
except ImportError as e:
cls.has_torch_rnnt_loss = False
print(
f"Failed to import torchaudio, error message: {e}\n"
"Skipping the tests comparing rnnt loss with torchaudio's; "
"to enable them, please install torchaudio >= 0.10.0"
)
def test_rnnt_loss_basic(self):
B = 1
S = 3
T = 4
# C = 3
for device in self.devices:
# lm: [B][S+1][C]
lm = torch.tensor(
[[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
dtype=torch.float,
device=device,
)
# am: [B][T][C]
am = torch.tensor(
[[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
dtype=torch.float,
device=device,
)
termination_symbol = 2
symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device)
px, py = fast_rnnt.get_rnnt_logprobs(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
)
assert px.shape == (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
if device == torch.device("cpu"):
expected = -m
assert torch.allclose(-m, expected.to(device))
# test rnnt_loss_simple
m = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
# test rnnt_loss_smoothed
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
probs = am.unsqueeze(2) + lm.unsqueeze(1)
# test rnnt_loss
m = fast_rnnt.rnnt_loss(
logits=probs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss:
import torchaudio.functional
m = torchaudio.functional.rnnt_loss(
logits=probs,
targets=symbols.int(),
logit_lengths=torch.tensor(
[T] * B, dtype=torch.int32, device=device
),
target_lengths=torch.tensor(
[S] * B, dtype=torch.int32, device=device
),
blank=termination_symbol,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S + 1, 1, device=device)
am += torch.randn(B, T, 1, device=device)
m = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
probs = am.unsqueeze(2) + lm.unsqueeze(1)
m = fast_rnnt.rnnt_loss(
logits=probs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
def test_rnnt_loss_random(self):
B = 5
S = 20
T = 300
C = 100
frames = torch.randint(S, T, (B,))
seq_length = torch.randint(3, S - 1, (B,))
T = torch.max(frames)
S = torch.max(seq_length)
am_ = torch.randn((B, T, C), dtype=torch.float32)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols_ = torch.randint(0, C - 1, (B, S))
termination_symbol = C - 1
boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for device in self.devices:
# lm: [B][S+1][C]
lm = lm_.to(device)
# am: [B][T][C]
am = am_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)
px, py = fast_rnnt.get_rnnt_logprobs(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert px.shape == ((B, S, T) if modified else (B, S, T + 1))
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(
px=px, py=py, boundary=boundary
)
if device == torch.device("cpu"):
expected = -torch.mean(m)
assert torch.allclose(-torch.mean(m), expected.to(device))
m = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
probs = am.unsqueeze(2) + lm.unsqueeze(1)
m = fast_rnnt.rnnt_loss(
logits=probs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss and not modified:
import torchaudio.functional
m = torchaudio.functional.rnnt_loss(
logits=probs,
targets=symbols.int(),
logit_lengths=boundary[:, 3].int(),
target_lengths=boundary[:, 2].int(),
blank=termination_symbol,
)
assert torch.allclose(m, expected.to(device))
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S + 1, 1, device=device)
am += torch.randn(B, T, 1, device=device)
m = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
probs = am.unsqueeze(2) + lm.unsqueeze(1)
m = fast_rnnt.rnnt_loss(
logits=probs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
def test_rnnt_loss_gradient(self):
if self.has_torch_rnnt_loss:
import torchaudio.functional
B = 5
S = 20
T = 300
C = 100
frames = torch.randint(S, T, (B,))
seq_length = torch.randint(3, S - 1, (B,))
T = torch.max(frames)
S = torch.max(seq_length)
am_ = torch.randn((B, T, C), dtype=torch.float32)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols_ = torch.randint(0, C - 1, (B, S))
termination_symbol = C - 1
boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for device in self.devices:
# lm: [B][S+1][C]
lm = lm_.to(device)
# am: [B][T][C]
am = am_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)
logprobs = am.unsqueeze(2) + lm.unsqueeze(1)
logprobs.requires_grad_()
k2_loss = fast_rnnt.rnnt_loss(
logits=logprobs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
)
k2_grad = torch.autograd.grad(k2_loss, logprobs)
k2_grad = k2_grad[0]
logprobs2 = logprobs.detach().clone().float()
logprobs2.requires_grad_()
torch_loss = torchaudio.functional.rnnt_loss(
logits=logprobs2,
targets=symbols.int(),
logit_lengths=boundary[:, 3].int(),
target_lengths=boundary[:, 2].int(),
blank=termination_symbol,
)
torch_grad = torch.autograd.grad(torch_loss, logprobs2)
torch_grad = torch_grad[0]
assert torch.allclose(k2_loss, torch_loss, atol=1e-2, rtol=1e-2)
assert torch.allclose(k2_grad, torch_grad, atol=1e-2, rtol=1e-2)
def test_rnnt_loss_smoothed(self):
B = 1
S = 3
T = 4
# C = 3
for device in self.devices:
# lm: [B][S+1][C]
lm = torch.tensor(
[[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
dtype=torch.float,
device=device,
)
# am: [B][T][C]
am = torch.tensor(
[[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
dtype=torch.float,
device=device,
)
termination_symbol = 2
symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device)
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.333,
boundary=None,
)
if device == torch.device("cpu"):
expected = m
assert torch.allclose(m, expected.to(device))
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S + 1, 1, device=device)
am += torch.randn(B, T, 1, device=device)
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.333,
boundary=None,
)
assert torch.allclose(m, expected.to(device))
def test_rnnt_loss_pruned(self):
B = 4
T = 300
S = 50
C = 10
frames = torch.randint(S, T, (B,))
seq_length = torch.randint(3, S - 1, (B,))
T = torch.max(frames)
S = torch.max(seq_length)
am_ = torch.randn((B, T, C), dtype=torch.float64)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
symbols_ = torch.randint(0, C - 1, (B, S))
terminal_symbol = C - 1
boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
lm = lm_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)
t_am = am.unsqueeze(2).float()
t_lm = lm.unsqueeze(1).float()
t_prob = t_am + t_lm
# nonlinear transform
t_prob = torch.sigmoid(t_prob)
k2_loss = fast_rnnt.rnnt_loss(
logits=t_prob,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
)
print(
f"unpruned rnnt loss with modified {modified} : {k2_loss}"
)
# pruning
k2_simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
return_grad=True,
reduction="none",
)
for r in range(2, 50, 5):
ranges = fast_rnnt.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=r,
)
# (B, T, r, C)
am_p, lm_p = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
t_prob_p = am_p + lm_p
# nonlinear transform
t_prob_p = torch.sigmoid(t_prob_p)
pruned_loss = fast_rnnt.rnnt_loss_pruned(
logits=t_prob_p,
symbols=symbols,
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
reduction="none",
)
print(f"pruning loss with range {r} : {pruned_loss}")
if __name__ == "__main__":
unittest.main()
import os
import random
import time
import unittest
import torch
from tqdm import tqdm
from torch_discounted_cumsum import discounted_cumsum_left, discounted_cumsum_right
def get_grad(param, out):
out.sum().backward()
grad = param.grad.clone()
del param.grad
return grad
def discounted_cumsum_left_gold(input, gamma):
# Reference implementation of the left-to-right discounted cumsum:
# out[:, i] = input[:, i] + gamma * out[:, i - 1]
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in range(input.shape[1]):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.append(last_col)
out = torch.cat(out, dim=1)
return out
def discounted_cumsum_right_gold(input, gamma):
# Reference implementation of the right-to-left discounted cumsum:
# out[:, i] = input[:, i] + gamma * out[:, i + 1]
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in reversed(range(input.shape[1])):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.insert(0, last_col)
out = torch.cat(out, dim=1)
return out
def discounted_cumsum_lib(x, gamma, dir):
return {
'left': discounted_cumsum_left,
'right': discounted_cumsum_right,
}[dir](x, gamma)
def discounted_cumsum_gold(x, gamma, dir):
return {
'left': discounted_cumsum_left_gold,
'right': discounted_cumsum_right_gold,
}[dir](x, gamma)
def compute_linf(batchsz, veclen, dir, gamma=0.99, dtype=torch.float32, cuda=False, data='randn', tol=1e-3, seed=2021):
torch.manual_seed(seed)
if data == 'randn':
x = torch.randn((batchsz, veclen), dtype=dtype)
elif data == 'ones':
x = torch.ones((batchsz, veclen), dtype=dtype)
else:
raise ValueError('Invalid data generation identifier')
if cuda:
x = x.cuda()
x = torch.nn.Parameter(x)
out_gold = discounted_cumsum_gold(x, gamma, dir)
grad_gold = get_grad(x, out_gold)
out_lib = discounted_cumsum_lib(x, gamma, dir)
grad_lib = get_grad(x, out_lib)
out_linf = (out_lib - out_gold).abs().max().item()
grad_linf = (grad_lib - grad_gold).abs().max().item()
if out_linf >= tol or grad_linf >= tol:
print(f'x={x}\nout_gold={out_gold}\nout_lib={out_lib}\ngrad_gold={grad_gold}\ngrad_lib={grad_lib}\n')
return out_linf, grad_linf
class TestDiscountedCumSum(unittest.TestCase):
def test_validity(self):
print('Testing validity...')
is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
for cuda in (True, False):
if cuda and not is_cuda:
print('Skipping validity CUDA tests')
continue
rng = random.Random(2021)
with tqdm(total=2*2*2*17) as pbar:
for data in ('ones', 'randn'):
for dtype in (torch.float32, torch.float64):
for i in range(2):
batchsz = 8 ** i
for j in range(17):
veclen = max(1, 2 ** j + rng.randint(-1, 1))
gamma = rng.random()
seed = rng.randint(0, 2 ** 16)
dir = rng.choice(['left', 'right'])
tol = 2e-3
out_linf, grad_linf = compute_linf(
batchsz, veclen, dir, gamma, dtype, cuda, data, tol, seed
)
msg = f'Validity test failed with batchsz={batchsz}, veclen={veclen}, dir={dir}, ' \
f'gamma={gamma}, dtype={dtype}, cuda={cuda}, data={data}, seed={seed}, ' \
f'out_linf={out_linf}, grad_linf={grad_linf}'
self.assertLess(out_linf, tol, msg)
self.assertLess(grad_linf, tol, msg)
pbar.update(1)
def test_precision(self):
print('Testing precision...')
is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
if not is_cuda:
print('Skipping precision tests')
return
batchsz = 1
veclen = 10000
gamma = 0.99
dir = 'right'
for data in ('ones', 'randn'):
if data == 'ones':
precision_factor = 2.0
else:
precision_factor = 1.1
torch.manual_seed(2021)
if data == 'randn':
x_32 = torch.randn((batchsz, veclen), dtype=torch.float32)
elif data == 'ones':
x_32 = torch.ones((batchsz, veclen), dtype=torch.float32)
else:
raise ValueError('Invalid data generation identifier')
x_32 = x_32.cuda()
x_64 = x_32.double()
gold_64 = discounted_cumsum_gold(x_64, gamma, dir)
gold_32 = discounted_cumsum_gold(x_32, gamma, dir).double()
lib_32 = discounted_cumsum_lib(x_32, gamma, dir).double()
err_32_gold = (gold_32 - gold_64).abs().max().item()
err_32_lib = (lib_32 - gold_64).abs().max().item()
msg = f'Precision improvement test failed with data={data}, ' \
f'err_32_gold={err_32_gold}, err_32_lib={err_32_lib}'
self.assertLess(precision_factor * err_32_lib, err_32_gold, msg)
print(f'data={data}\nerr_32_gold={err_32_gold:10.8f}\nerr_32_lib ={err_32_lib:10.8f}')
def test_speed(self):
print('Testing speed...')
is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
NUM_RUNS = 30
NUM_RUNS_GOLD = 6
if not is_cuda:
print('Skipping speed tests')
return
gamma = 0.99
x_32 = torch.randn((1, 100000), dtype=torch.float32)
x_32 += torch.ones_like(x_32)
x_32_gpu = x_32.cuda()
timer = time.clock_gettime(time.CLOCK_MONOTONIC)
for _ in tqdm(range(NUM_RUNS_GOLD), desc='gold', leave=True):
discounted_cumsum_right_gold(x_32, gamma)
dur_gold = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
dur_gold = dur_gold * NUM_RUNS / NUM_RUNS_GOLD
timer = time.clock_gettime(time.CLOCK_MONOTONIC)
for _ in tqdm(range(NUM_RUNS), desc='lib_cpu', leave=True):
discounted_cumsum_right(x_32, gamma)
dur_lib_cpu = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
timer = time.clock_gettime(time.CLOCK_MONOTONIC)
for _ in tqdm(range(NUM_RUNS), desc='lib_cuda', leave=True):
discounted_cumsum_right(x_32_gpu, gamma)
dur_lib_cuda = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
print(f'dur_gold: {dur_gold:7.4f} sec')
print(f'dur_lib_cpu: {dur_lib_cpu:7.4f} sec')
print(f'dur_lib_cuda: {dur_lib_cuda:7.4f} sec')
print(f'speedup gold -> lib_cpu: {dur_gold / dur_lib_cpu:5.2f}')
print(f'speedup gold -> lib_cuda: {dur_gold / dur_lib_cuda:5.2f}')
print(f'speedup lib_cpu -> lib_cuda: {dur_lib_cpu / dur_lib_cuda:5.2f}')
if __name__ == '__main__':
unittest.main()
from .mutual_information import mutual_information_recursion, joint_mutual_information_recursion
from .rnnt import get_rnnt_logprobs, rnnt_loss_simple, rnnt_loss_aux
import os
import torch
from torch import Tensor
from typing import Tuple, Optional, Sequence
from torch.utils.cpp_extension import load
VERBOSE = False
def _resolve(name):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
try:
import torch_mutual_information_cpu
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_mutual_information_cpu')
torch_mutual_information_cpu = load(
name='torch_mutual_information_cpu',
sources=[
_resolve('mutual_information_cpu.cpp'),
],
verbose=VERBOSE,
)
try:
import torch_mutual_information_cuda
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_mutual_information_cuda')
torch_mutual_information_cuda = None
if torch.cuda.is_available():
torch_mutual_information_cuda = load(
name='torch_mutual_information_cuda',
sources=[
_resolve('mutual_information_cuda.cpp'),
_resolve('mutual_information_cuda_kernel.cu'),
],
verbose=VERBOSE,
)
def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundary: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return torch_mutual_information_cuda.mutual_information_cuda(
px, py, boundary, p)
else:
return torch_mutual_information_cpu.mutual_information_cpu(
px, py, boundary, p)
def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundary: torch.Tensor, p: torch.Tensor,
ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
overwrite_ans_grad = True
if overwrite_ans_grad:
ans_grad_copy = ans_grad.clone()
ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
px, py, boundary, p, ans_grad_copy, overwrite_ans_grad))
if overwrite_ans_grad:
if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
print(f"Warning: possible excesssive roundoff in mutual information backward "
f"recursion: {ans_grad} vs. {ans_grad_copy}");
return ans
else:
return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
px, py, boundary, p, ans_grad))
class MutualInformationRecursionFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundary: Optional[torch.Tensor]) -> torch.Tensor:
(B, S, T1) = px.shape
T = T1 - 1
assert py.shape == (B, S + 1, T)
if boundary is not None:
assert boundary.shape == (B, 4)
else:
boundary = torch.zeros(0, 0, dtype=torch.int64, device=px.device)
# p is a tensor of shape (B, S + 1, T + 1), where p[s][t] is the
# mutual information of the pair of subsequences of x and y that are of
# length s and t respectively. p[0][0] will be 0.0 and p[S][T] is
# the mutual information of the entire pair of sequences, i.e. of lengths
# S and T respectively.
# It is computed as follows (in C++ and CUDA):
# p[b,0,0] = 0.0
# p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
# p[b,s,t-1] + py[b,s,t-1])
# if s > 0 or t > 0,
# treating values with any -1 index as -infinity.
# .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _mutual_information_forward_dispatcher(px, py, boundary, p)
# print(f"p = {p}, boundary = {boundary}, psum={p.sum()}")
if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundary, p)
return ans
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
(px, py, boundary, p) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundary, p, ans_grad)
return (px_grad, py_grad, None)
def mutual_information_recursion(px, py, boundary=None):
"""A recursion that is useful in computing mutual information between two sequences of
real vectors, but may be useful more generally in sequence-to-sequence tasks where
monotonic alignment between pairs of sequences is desired. The definitions of
the arguments are definitions that would be used when computing this type of
mutual information, but you can also view them as arbitrary quantities and just
make use of the formula computed by this function.
Args:
px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
where B is the batch size, S is the length of the 'x' sequence
(including representations of EOS symbols but not BOS symbols), and T is the
length of the 'y' sequence (including representations of
EOS symbols but not BOS symbols). In the mutual information application,
px[b][s][t] would represent the following log odds ratio; ignoring
the b index on the right to make the notation more compact,
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an x value as opposed to a y value. In
practice it might be computed as a + b, where a is the log
probability of choosing to extend the sequence of length (s,t)
with an x as opposed to a y value; and b might in practice be
of the form:
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where N is the number of terms that the sum over t' included, which
might include some or all of the other sequences as well as this one.
Note: we don't require px and py to be contiguous, but the
code assumes for optimization purposes that the T axis has
stride 1.
py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
representing
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat x and y differently; the only difference
is that for optimization purposes we assume the last axis (the t axis)
has stride of 1; this is true if px and py are contiguous.
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S and
0 <= t_begin <= t_end <= T (this implies that empty sequences are
allowed). If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
Returns:
Returns a torch.Tensor of shape [B], containing the log of the mutual
information between the b'th pair of sequences. This is defined by
the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
representing a mutual information between sub-sequences of lengths s and t:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
where we handle edge cases by treating quantities with negative indexes
as -infinity. The extension to cases where the boundaries are specified
should be obvious; it just works on shorter sequences with offsets into
px and py.
"""
assert px.ndim == 3
B, S, T1 = px.shape
T = T1 - 1
assert py.shape == (B, S + 1, T)
assert px.dtype == py.dtype
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for [ s_begin, t_begin, s_end, t_end ] in boundary.to('cpu').tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following assertions are for efficiency
assert px.stride()[-1] == 1
assert py.stride()[-1] == 1
return MutualInformationRecursionFunction.apply(px, py, boundary)
def _inner(a: Tensor, b: Tensor) -> Tensor:
"""
Does inner product on the last dimension, with expected broadcasting, i.e. equivalent to
(a * b).sum(dim=-1)
without creating a large temporary.
"""
assert a.shape[-1] == b.shape[-1]  # the last dim (call it K) must match
a = a.unsqueeze(-2) # (..., 1, K)
b = b.unsqueeze(-1) # (..., K, 1)
c = torch.matmul(a, b) # (..., 1, 1)
return c.squeeze(-1).squeeze(-1)
def joint_mutual_information_recursion(px: Sequence[Tensor], py: Sequence[Tensor],
boundary: Optional[Tensor] = None) -> Sequence[Tensor]:
"""A recursion that is useful for modifications of RNN-T and similar loss functions,
where the recursion probabilities have a number of terms and you want them reported
separately. See mutual_information_recursion() for more documentation of the
basic aspects of this.
Args:
px: a sequence of Tensors, each of the same shape [B][S][T+1]
py: a sequence of Tensor, each of the same shape [B][S+1][T], the sequence must be
the same length as px.
boundary: optionally, a LongTensor of shape [B][4] containing rows
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S and
0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
Returns:
a Tensor of shape (len(px), B),
whose sum over dim 0 is the total log-prob of the recursion mentioned below, per sequence.
The first element of the sequence of length len(px) is "special", in that it has an offset term
reflecting the difference between sum-of-log and log-of-sum; for more interpretable
loss values, the "main" part of your loss function should be first.
The recursion below applies if boundary == None, when it defaults
to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px and py:
p = tensor of shape (B, S+1, T+1), containing -infinity
p[b,0,0] = 0.0
# do the following in loop over s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
p[b,s,t-1] + py_sum[b,s,t-1])
(if s > 0 or t > 0)
return p[:][S][T]
This function lets you implement the above recursion efficiently, except
that it gives you a breakdown of the contribution from all the elements of
px and py separately. As noted above, the first element of the
sequence is "special".
"""
N = len(px)
assert len(py) == N and N > 0
B, S, T1 = px[0].shape
T = T1 - 1
assert py[0].shape == (B, S + 1, T)
assert px[0].dtype == py[0].dtype
px_cat = torch.stack(px, dim=0) # (N, B, S, T+1)
py_cat = torch.stack(py, dim=0) # (N, B, S+1, T)
px_tot = px_cat.sum(dim=0) # (B, S, T+1)
py_tot = py_cat.sum(dim=0) # (B, S+1, T)
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for [ s_begin, t_begin, s_end, t_end ] in boundary.to('cpu').tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
else:
boundary = torch.zeros(0, 0, dtype=torch.int64, device=px_tot.device)
px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
# The following assertions are for efficiency
assert px_tot.stride()[-1] == 1 and px_tot.ndim == 3
assert py_tot.stride()[-1] == 1 and py_tot.ndim == 3
p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)
# note, tot_probs is without grad.
tot_probs = _mutual_information_forward_dispatcher(px_tot, py_tot, boundary, p)
# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px_tot, py_tot,
boundary, p, ans_grad)
px_grad, py_grad = px_grad.reshape(1, B, -1), py_grad.reshape(1, B, -1)
px_cat, py_cat = px_cat.reshape(N, B, -1), py_cat.reshape(N, B, -1)
x_prods = _inner(px_grad, px_cat) # (N, B)
y_prods = _inner(py_grad, py_cat) # (N, B)
# If all the occupation counts were exactly 1.0 (i.e. no partial counts),
# "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
# will be more positive due to the difference between log-of-sum and
# sum-of-log
prods = x_prods + y_prods # (N, B)
with torch.no_grad():
offset = tot_probs - prods.sum(dim=0) # (B,)
prods[0] += offset
return prods # (N, B)
#include <math.h> // for log1p, log1pf
#include <torch/extension.h>
inline double Exp(double x) {
return exp(x);
}
inline float Exp(float x) {
return expf(x);
}
// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= -1000) {
double res;
res = x + log1p(exp(diff));
return res;
}
return x; // return the larger one.
}
// returns log(exp(x) + exp(y)).
inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= -200) {
float res;
res = x + log1pf(expf(diff));
return res;
}
return x; // return the larger one.
}
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
// px: of shape [B, S, T+1]; py: of shape [B, S+1, T]; p (output): of shape [B, S+1, T+1].
torch::Tensor mutual_information_cpu(torch::Tensor px,
torch::Tensor py,
torch::Tensor boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0),
S = px.size(1),
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4));
TORCH_CHECK(boundary.device().is_cpu() &&
boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
bool has_boundary = (boundary.size(0) != 0);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.packed_accessor32<scalar_t, 3>(),
py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>();
auto boundary_a = boundary.packed_accessor32<int64_t, 2>();
auto ans_a = ans.packed_accessor32<scalar_t, 1>();
for (int b = 0; b < B; b++) {
int s_begin, s_end, t_begin, t_end;
if (has_boundary) {
s_begin = boundary_a[b][0];
t_begin = boundary_a[b][1];
s_end = boundary_a[b][2];
t_end = boundary_a[b][3];
} else {
s_begin = 0;
t_begin = 0;
s_end = S;
t_end = T;
}
p_a[b][s_begin][t_begin] = 0.0;
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t)
p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
for (int s = s_begin + 1; s <= s_end; ++s) {
scalar_t p_s_t1 = p_a[b][s][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t) {
// The following statement is a small optimization of:
// p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
p_s_t1 + py_a[b][s][t - 1]);
}
}
ans_a[b] = p_a[b][s_end][t_end];
}
}));
return ans;
}
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor> mutual_information_backward_cpu(
torch::Tensor px,
torch::Tensor py,
torch::Tensor boundary,
torch::Tensor p,
torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu()
&& ans_grad.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0),
S = px.size(1),
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4));
TORCH_CHECK(boundary.device().is_cpu() &&
boundary.dtype() == torch::kInt64);
bool has_boundary = (boundary.size(0) != 0);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
torch::empty({B, S, T + 1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
torch::empty({B, S + 1, T}, opts));
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.packed_accessor32<scalar_t, 3>(),
// py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>(),
p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
auto boundary_a = boundary.packed_accessor32<int64_t, 2>();
for (int b = 0; b < B; b++) {
int s_begin, s_end, t_begin, t_end;
if (has_boundary) {
s_begin = boundary_a[b][0];
t_begin = boundary_a[b][1];
s_end = boundary_a[b][2];
t_end = boundary_a[b][3];
} else {
s_begin = 0;
s_end = S;
t_begin = 0;
t_end = T;
}
// Backprop for: ans_a[b] = p_a[b][s_end][t_end];
p_grad_a[b][s_end][t_end] = ans_grad_a[b];
for (int s = s_end; s > s_begin; --s) {
for (int t = t_end; t > t_begin; --t) {
// The statement we are backpropagating here is:
// p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t],
term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t],
term1_grad = term1_deriv * grad,
term2_grad = term2_deriv * grad;
px_grad_a[b][s - 1][t] = term1_grad;
p_grad_a[b][s - 1][t] = term1_grad;
py_grad_a[b][s][t - 1] = term2_grad;
p_grad_a[b][s][t - 1] += term2_grad;
}
}
for (int t = t_end; t > t_begin; --t) {
// Backprop for:
// p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
scalar_t this_p_grad = p_grad_a[b][s_begin][t];
p_grad_a[b][s_begin][t - 1] += this_p_grad;
py_grad_a[b][s_begin][t - 1] = this_p_grad;
}
for (int s = s_end; s > s_begin; --s) {
// Backprop for:
// p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
p_grad_a[b][s - 1][t_begin] += this_p_grad;
px_grad_a[b][s - 1][t_begin] = this_p_grad;
}
// There is no backprop for:
// p_a[b][s_begin][t_begin] = 0.0;
// .. but we can use this for a check, that the grad at the beginning
// of the sequence is equal to the grad at the end of the sequence.
if (ans_grad_a[b] != 0.0) {
float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
if (fabs(grad_ratio - 1.0) > 0.01) {
printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
(float)p_grad_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
}
}
}
}));
// std::cout << "p_grad = " << p_grad;
return std::vector<torch::Tensor>({px_grad, py_grad});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("mutual_information_cpu", &mutual_information_cpu, "Integrated convolution forward function (CPU)");
m.def("mutual_information_backward_cpu", &mutual_information_backward_cpu, "Integrated convolution backward function (CPU)");
}
#include <torch/extension.h>
/*
Forward of mutual_information. See also the """... """ comment of
`mutual_information` in mutual_information.py. It is the core recursion
in the sequence-to-sequence mutual information computation.
Args:
px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
generating the next x in the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of lengths
(s, t), divided by the prior probability of generating x_s. (See
mutual_information.py for more info).
py: The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
p: This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if s > 0 or t > 0,
treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
boundary: If set, a tensor of shape [B][4] of type int64_t, where
for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. Alternatively, may be
a tensor of shape [0][0] and type int64_t; the elements will
default to (0, 0, S, T).
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with (s_end, t_end) being (boundary[b][2], boundary[b][3]) if `boundary`
was specified, and (S, T) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::Tensor boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
/*
backward of mutual_information; returns (grad_px, grad_py)
if overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor> mutual_information_backward_cuda(
torch::Tensor px,
torch::Tensor py,
torch::Tensor boundary,
torch::Tensor p,
torch::Tensor ans_grad,
bool overwrite_ans_grad);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("mutual_information_cuda", &mutual_information_cuda, "Mutual information forward function (CUDA)");
m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda, "Mutual information backward function (CUDA)");
}
# Caution: this will fail occasionally due to cutoffs not being quite large enough.
# As long as it passes most of the time, it's OK.
import random
import torch
from torch_mutual_information import mutual_information_recursion, joint_mutual_information_recursion
def test_mutual_information_basic():
print("Running test_mutual_information_basic()")
for _iter in range(100):
(B, S, T) = (random.randint(1, 10),
random.randint(1, 200),
random.randint(1, 200))
random_px = (random.random() < 0.2)
random_py = (random.random() < 0.2)
random_boundary = (random.random() < 0.7)
big_px = (random.random() < 0.2)
big_py = (random.random() < 0.2)
print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}, big_px={big_px}, big_py={big_py}, random_boundary={random_boundary}")
for dtype in [torch.float32, torch.float64]:
px_grads = []
py_grads = []
m_vals = []
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
if random_boundary:
def get_boundary_row():
s_begin = random.randint(0, S - 1)
t_begin = random.randint(0, T - 1)
s_end = random.randint(s_begin, S) # allow empty sequence
t_end = random.randint(t_begin, T) # allow empty sequence
return [s_begin, t_begin, s_end, t_end]
if device == torch.device('cpu'):
boundary = torch.tensor([ get_boundary_row() for _ in range(B) ],
dtype=torch.int64, device=device)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly or not.
if random.random() < 0.5:
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
else:
boundary = None
if device == torch.device('cpu'):
if random_px:
px = torch.randn(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
else:
px = torch.zeros(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
# px and py get exponentiated, and then multiplied together up to
# 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
# could lead to overflow.
if big_px:
px += 15.0
if random_py:
py = torch.randn(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
else:
py = torch.zeros(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
#m = mutual_information_recursion(px, py, None)
m = mutual_information_recursion(px, py, boundary)
m2 = joint_mutual_information_recursion((px,), (py,), boundary)
m3 = joint_mutual_information_recursion((px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary)
print("m3, before sum, = ", m3)
m3 = m3.sum(dim=0) # it is supposed to be identical only after
# summing over dim 0, corresponding to the
# sequence dim
print("m = ", m, ", size = ", m.shape)
print("m2 = ", m2, ", size = ", m2.shape)
print("m3 = ", m3, ", size = ", m3.shape)
assert torch.allclose(m, m2)
assert torch.allclose(m, m3)
#print("exp(m) = ", m.exp())
# the loop this is in checks that the CPU and CUDA versions give the same
# derivative; by randomizing which of m, m2 or m3 we backprop, we also
# ensure that the joint version of the code gives the same derivative
# as the regular version
scale = 3
if random.random() < 0.5:
(m.sum() * scale).backward()
elif random.random() < 0.5:
(m2.sum() * scale).backward()
else:
(m3.sum() * scale).backward()
#print("px_grad = ", px.grad)
#print("py_grad = ", py.grad)
px_grads.append(px.grad.to('cpu'))
py_grads.append(py.grad.to('cpu'))
m_vals.append(m.to('cpu'))
if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-02, rtol=1.0e-02):
print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
assert 0
if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-02, rtol=1.0e-02):
print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
assert 0
if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-02, rtol=1.0e-02):
print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
assert 0
def test_mutual_information_deriv():
print("Running test_mutual_information_deriv()")
for _iter in range(100):
(B, S, T) = (random.randint(1, 10),
random.randint(1, 200),
random.randint(1, 200))
random_px = (random.random() < 0.2)
random_py = (random.random() < 0.2)
random_boundary = (random.random() < 0.7)
big_px = (random.random() < 0.2)
big_py = (random.random() < 0.2)
print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}, big_px={big_px}, big_py={big_py}, random_boundary={random_boundary}")
for dtype in [torch.float32, torch.float64]:
#px_grads = []
#py_grads = []
#m_vals = []
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
if random_boundary:
def get_boundary_row():
s_begin = random.randint(0, S - 1)
t_begin = random.randint(0, T - 1)
s_end = random.randint(s_begin + 1, S)
t_end = random.randint(t_begin + 1, T)
return [s_begin, t_begin, s_end, t_end]
if device == torch.device('cpu'):
boundary = torch.tensor([ get_boundary_row() for _ in range(B) ],
dtype=torch.int64, device=device)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly or not.
if random.random() < 0.5:
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
else:
boundary = None
if device == torch.device('cpu'):
if random_px:
px = torch.randn(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
else:
px = torch.zeros(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
# px and py get exponentiated, and then multiplied together up to
# 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
# could lead to overflow.
if big_px:
px += 15.0
if random_py:
py = torch.randn(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
else:
py = torch.zeros(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
m = mutual_information_recursion(px, py, boundary)
#print("m = ", m)
#print("exp(m) = ", m.exp())
#print("px_grad = ", px.grad)
#print("py_grad = ", py.grad)
#px_grads.append(px.grad.to('cpu'))
#py_grads.append(py.grad.to('cpu'))
#m_vals.append(m.to('cpu'))
m_grad = torch.randn(B, dtype=dtype, device=device)
m.backward(gradient=m_grad)
delta = 1.0e-04
delta_px = delta * torch.randn_like(px)
m2 = mutual_information_recursion(px + delta_px, py, boundary)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to('cpu')
predicted_delta = (delta_px * px.grad).sum().to('cpu')
print(f"For px: observed,predicted objf changes are: {observed_delta},{predicted_delta}, absolute objf was {(m * m_grad).sum()}")
atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
if not torch.allclose(observed_delta, predicted_delta, atol=atol, rtol=rtol):
print(f"Error: observed and predicted delta too different.")
assert 0
delta_py = delta * torch.randn_like(py)
m2 = mutual_information_recursion(px, py + delta_py, boundary)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to('cpu')
predicted_delta = (delta_py * py.grad).sum().to('cpu')
print(f"For py: observed,predicted objf changes are: {observed_delta},{predicted_delta}, absolute objf was {(m * m_grad).sum()}")
# if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-02, rtol=1.0e-02):
# print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
# assert 0
# if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-02, rtol=1.0e-02):
# print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
# assert 0
# if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-02, rtol=1.0e-02):
# print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
# assert 0
if __name__ == "__main__":
#torch.set_printoptions(edgeitems=30)
test_mutual_information_basic()
test_mutual_information_deriv()
import os
import torch
from torch import Tensor
from typing import Tuple, Optional
from .mutual_information import mutual_information_recursion, joint_mutual_information_recursion
def get_rnnt_logprobs(lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int) -> Tuple[Tensor, Tensor]:
"""
Reduces the RNN-T problem (the simple case, where the joiner network is just addition)
to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion(). This function is called from
rnnt_loss_simple(), but may be useful for other purposes.
Args:
lm: Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape:
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am: Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape:
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols: A LongTensor of shape [B][S], containing the symbols at each position
of the sequence, possibly including EOS
termination_symbol: The identity of the termination symbol, must be
in {0..C-1}
Returns: (px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1]
py: logprobs, of shape [B][S+1][T]
in the recursion:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
we cannot emit any symbols. This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert lm.ndim == 3 and am.ndim == 3 and lm.shape[0] == am.shape[0] \
    and lm.shape[2] == am.shape[2]
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
# Subtracting am_max and lm_max ensures the values are in a good range before
# calling exp(), so that it does not underflow or overflow.
am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1]
lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1]
am_probs = (am - am_max).exp()
lm_probs = (lm - lm_max).exp()
# normalizers: [B][S+1][T]
normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) + 1.0e-20).log()
# add lm_max and am_max to normalizers, to make it as if we had not
# subtracted am_max and lm_max above.
normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T]
# px is the probs of the actual symbols..
px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)).squeeze(-1) # [B][S][T]
px_am = torch.cat((px_am,
torch.full((B, S, 1), float('-inf'),
device=px_am.device, dtype=px_am.dtype)),
dim=2) # now: [B][S][T+1], index [:,:,T] has -inf..
px_lm = torch.gather(lm[:,:S], dim=2, index=symbols.unsqueeze(-1)) # [B][S][1]
px = px_am + px_lm # [B][S][T+1], last slice indexed [:,:,T] is -inf
px[:,:,:T] -= normalizers[:,:S,:] # px: [B][S][T+1]
# py is the probs of termination symbols, of shape [B][S+1][T]
py_am = am[:,:,termination_symbol].unsqueeze(1) # [B][1][T]
py_lm = lm[:,:,termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
return (px, py)
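# Hedged illustration, not part of the original module: a naive O(S*T) sketch
# of the recursion described in the docstring of get_rnnt_logprobs(), for
# batch element 0 with the default boundary [0, 0, S, T].  It only exists to
# make the meaning of (px, py) concrete; mutual_information_recursion() is the
# real, vectorized implementation, and this result should agree with its
# output for b == 0.  The helper name is hypothetical.
def _naive_joint_score_sketch(px: Tensor, py: Tensor) -> Tensor:
    B, S, T_plus_1 = px.shape
    T = T_plus_1 - 1
    neg_inf = torch.tensor(float('-inf'), dtype=px.dtype)
    # p[s][t] is the "joint score" of subsequences of length s and t.
    p = torch.full((S + 1, T + 1), float('-inf'), dtype=px.dtype)
    p[0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            # Extend in the symbol (s) direction, if possible ...
            term_s = p[s - 1, t] + px[0, s - 1, t] if s > 0 else neg_inf
            # ... or emit the termination/next-frame symbol (t direction).
            term_t = p[s, t - 1] + py[0, s, t - 1] if t > 0 else neg_inf
            p[s, t] = torch.logaddexp(term_s, term_t)
    return p[S, T]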
def rnnt_loss_simple(lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None) -> Tensor:
"""
A simple case of the RNN-T loss, where the 'joiner' network is just addition.
Returns negated total loss value.
Args:
lm: language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes
am: acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols: the symbol sequences, a LongTensor of shape [B][S], and elements in {0..C-1}.
termination_symbol: the termination symbol, with 0 <= termination_symbol < C
boundary: an optional LongTensor of shape [B, 4] with rows interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]; if boundary is not
supplied, it is treated as [0, 0, S, T].
Most likely you will want begin_symbol and begin_frame to be zero.
Returns:
a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
for each element of the batch (like log-probs of sequences).
"""
px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
return mutual_information_recursion(px, py, boundary)
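# Hedged usage sketch, not part of the original module: shows the tensor
# shapes rnnt_loss_simple() expects.  The sizes and the helper name are
# illustrative only.
def _rnnt_loss_simple_usage_sketch() -> Tensor:
    B, S, T, C = 2, 5, 8, 10
    lm = torch.randn(B, S + 1, C)          # language-model logits, [B][S+1][C]
    am = torch.randn(B, T, C)              # acoustic-model logits, [B][T][C]
    symbols = torch.randint(0, C, (B, S))  # reference symbols, [B][S]
    termination_symbol = 0
    # Returns a (B,)-shaped tensor of NEGATED total losses (log-prob-like).
    return rnnt_loss_simple(lm, am, symbols, termination_symbol)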
def get_rnnt_logprobs_aux(lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1) -> Tuple[Tensor, Tensor]:
"""
Reduces the RNN-T problem (the simple case, where the joiner network is just
addition) to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion(). This version allows you
to make the loss-function one of the form:
lm_only_scale * lm_probs +
am_only_scale * am_probs +
(1-lm_only_scale-am_only_scale) * combined_probs
where lm_probs and am_probs are the probabilities given the lm and acoustic model
independently.
This function is called from
rnnt_loss_aux(), but may be useful for other purposes.
Args:
lm: Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape:
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am: Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape:
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols: A LongTensor of shape [B][S], containing the symbols at each position
of the sequence, possibly including EOS
termination_symbol: The identity of the termination symbol, must be
in {0..C-1}
Returns: (px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1]
py: logprobs, of shape [B][S+1][T]
in the recursion:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
we cannot emit any symbols. This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert lm.ndim == 3 and am.ndim == 3 and lm.shape[0] == am.shape[0] \
    and lm.shape[2] == am.shape[2]
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
# Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that
# all of the logprobs here are properly normalized. We test that
# this code is invariant to adding constants in the appropriate ways.
# Subtracting am_max and lm_max ensures the values are in a good range before
# calling exp(), so that it does not underflow or overflow.
am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1]
lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1]
am_probs = (am - am_max).exp() # [B][T][C]
lm_probs = (lm - lm_max).exp() # [B][S+1][C]
# normalizers: [B][S+1][T]
normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) + 1.0e-20).log()
# normalizer per frame, if we take only the LM probs by themselves
lmonly_normalizers = lm_probs.sum(dim=2, keepdim=True) # lmonly_normalizers: [B][S+1][1]
unigram_lm = torch.mean(lm_probs / lmonly_normalizers, dim=(0,1), keepdim=True) + 1.0e-20 # [1][1][C]
amonly_normalizers = torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C)).reshape(B, T, 1).log() + am_max # [B][T][1]
amonly_normalizers = amonly_normalizers.transpose(1, 2) # [B][1][T]
unigram_lm = unigram_lm.log()
lmonly_normalizers = lmonly_normalizers.log() + lm_max # [B][S+1][1], log-normalizer, used for LM-only part of prob.
# add lm_max and am_max to normalizers, to make it as if we had not
# subtracted am_max and lm_max above.
normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T]
# px is the probs of the actual symbols (not yet normalized)..
px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)).squeeze(-1) # [B][S][T]
px_am = torch.cat((px_am,
torch.full((B, S, 1), float('-inf'),
device=px_am.device, dtype=px_am.dtype)),
dim=2) # now: [B][S][T+1], index [:,:,T] has -inf..
px_lm = torch.gather(lm[:,:S], dim=2, index=symbols.unsqueeze(-1)) # [B][S][1]
px_lm_unigram = torch.gather(unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)) # [B][S][1]
px = px_am + px_lm # [B][S][T+1], last slice indexed [:,:,T] is -inf
px[:,:,:T] -= normalizers[:,:S,:] # px: [B][S][T+1]
px_amonly = px_am + px_lm_unigram # [B][S][T+1]
px_amonly[:,:,:T] -= amonly_normalizers
px_lmonly = px_lm - lmonly_normalizers[:,:S,:]
# py is the probs of termination symbols, of shape [B][S+1][T]
py_am = am[:,:,termination_symbol].unsqueeze(1) # [B][1][T]
py_lm = lm[:,:,termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
py_lm_unigram = unigram_lm[0][0][termination_symbol] # scalar, normalized..
py_amonly = py_am + py_lm_unigram - amonly_normalizers # [B][S+1][T]
py_lmonly = py_lm - lmonly_normalizers # [B][S+1][T]
combined_scale = 1.0 - lm_only_scale - am_only_scale
# We need to avoid exact zeros in the scales because otherwise multiplying -inf
# by zero generates nan.
if lm_only_scale == 0.0:
lm_only_scale = 1.0e-20
if am_only_scale == 0.0:
am_only_scale = 1.0e-20
px_interp = px * combined_scale + px_lmonly * lm_only_scale + px_amonly * am_only_scale
py_interp = py * combined_scale + py_lmonly * lm_only_scale + py_amonly * am_only_scale
print("px_interp = ", px_interp)
print("py_interp = ", py_interp)
return (px_interp, py_interp)
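# Hedged consistency sketch, not part of the original module: with
# lm_only_scale == am_only_scale == 0.0, combined_scale is 1.0 and the tiny
# 1.0e-20 replacement scales make the LM-only/AM-only terms negligible, so
# (px_interp, py_interp) should match the plain (px, py) from
# get_rnnt_logprobs() up to round-off.  The helper name is hypothetical.
def _check_aux_reduces_to_simple(lm: Tensor, am: Tensor, symbols: Tensor,
                                 termination_symbol: int) -> bool:
    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
    px_i, py_i = get_rnnt_logprobs_aux(lm, am, symbols, termination_symbol,
                                       lm_only_scale=0.0, am_only_scale=0.0)
    return bool(torch.allclose(px, px_i, atol=1.0e-05) and
                torch.allclose(py, py_i, atol=1.0e-05))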
def rnnt_loss_aux(lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None) -> Tensor:
"""
A simple case of the RNN-T loss, where the 'joiner' network is just addition,
with additional LM-only and AM-only terms interpolated into the loss
(see get_rnnt_logprobs_aux()). Returns negated total loss value.
Args:
lm: language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
These are assumed to be well-normalized, in the sense that we could
use them as probabilities separately from the am scores
am: acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols: the symbol sequences, a LongTensor of shape [B][S], and elements in {0..C-1}.
termination_symbol: the termination symbol, with 0 <= termination_symbol < C
lm_only_scale: the scale on the "LM-only" part of the loss.
am_only_scale: the scale on the "AM-only" part of the loss, for which we use
an "averaged" LM (averaged over all histories, so effectively unigram).
boundary: an optional LongTensor of shape [B, 4] with rows interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]; if boundary is not
supplied, it is treated as [0, 0, S, T].
Most likely you will want begin_symbol and begin_frame to be zero.
Returns:
a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
for each element of the batch (like log-probs of sequences).
"""
px, py = get_rnnt_logprobs_aux(lm, am, symbols, termination_symbol,
lm_only_scale, am_only_scale)
return mutual_information_recursion(px, py, boundary)
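# Hedged sketch, not part of the original module, of constructing the optional
# `boundary` argument described in the docstrings above: a LongTensor of shape
# [B, 4] whose rows are [begin_symbol, begin_frame, end_symbol, end_frame],
# typically with zero begins and per-sequence lengths as ends.  Names are
# illustrative only.
def _make_boundary_sketch(symbol_lens: Tensor, frame_lens: Tensor) -> Tensor:
    B = symbol_lens.shape[0]
    boundary = torch.zeros(B, 4, dtype=torch.int64)
    boundary[:, 2] = symbol_lens  # end_symbol: number of symbols per sequence
    boundary[:, 3] = frame_lens   # end_frame: number of acoustic frames per sequence
    return boundary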
import random
import torch
from torch_mutual_information import (mutual_information_recursion,
                                      joint_mutual_information_recursion,
                                      get_rnnt_logprobs, rnnt_loss_simple,
                                      rnnt_loss_aux)
def test_rnnt_logprobs_basic():
print("Running test_rnnt_logprobs_basic()")
B = 1
S = 3
T = 4
C = 3
# lm: [B][S+1][C]
lm = torch.tensor([[[ 0, 0, 1 ], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], dtype=torch.float)
# am: [B][T][C]
am = torch.tensor([[[ 0, 1, 2], [0, 0, 0 ], [0, 2, 4 ], [0, 3, 3]]], dtype=torch.float)
# lm[:] = 0.0
# am[:] = 0.0
termination_symbol = 2
symbols = torch.tensor([[ 0, 1, 0 ] ], dtype=torch.long)
px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
assert px.shape == (B, S, T+1)
assert py.shape == (B, S+1, T)
assert symbols.shape == (B, S)
print("px = ", px)
print("py = ", py)
m = mutual_information_recursion(px, py)
print("m = ", m)
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S+1, 1)
am += torch.randn(B, T, 1)
m2 = rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
print("m2 = ", m2)
device = torch.device('cuda')
m3 = rnnt_loss_simple(lm.to(device), am.to(device), symbols.to(device), termination_symbol, None)
print("m3 = ", m3)
m4 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device), termination_symbol,
lm_only_scale=0.0, am_only_scale=0.0, boundary=None)
print("m4 = ", m4)
assert torch.allclose(m, m2)
assert torch.allclose(m, m3.to('cpu'))
assert torch.allclose(m, m4.to('cpu'))
def test_rnnt_logprobs_aux():
print("Running test_rnnt_logprobs_aux()")
B = 1
S = 3
T = 4
C = 3
# lm: [B][S+1][C]
lm = torch.tensor([[[ 0, 0, 1 ], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], dtype=torch.float)
# am: [B][T][C]
am = torch.tensor([[[ 0, 1, 2], [0, 0, 0 ], [0, 2, 4 ], [0, 3, 3]]], dtype=torch.float)
termination_symbol = 2
symbols = torch.tensor([[ 0, 1, 0 ] ], dtype=torch.long)
device = torch.device('cuda')
m1 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device), termination_symbol,
lm_only_scale=0.0, am_only_scale=0.333, boundary=None)
print("m1 = ", m1)
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S+1, 1)
am += torch.randn(B, T, 1)
m2 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device), termination_symbol,
lm_only_scale=0.0, am_only_scale=0.333, boundary=None)
print("m2 = ", m2)
assert torch.allclose(m1, m2)
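# Hedged extra check, not part of the original test file and not wired into
# __main__ below: per the rnnt_loss_simple() docstring, an explicit boundary
# of [0, 0, S, T] should give the same result as the boundary=None default.
# The test name is illustrative only; call it manually if desired.
def test_rnnt_logprobs_boundary_sketch():
    B, S, T, C = 2, 3, 4, 5
    lm = torch.randn(B, S + 1, C)
    am = torch.randn(B, T, C)
    symbols = torch.randint(0, C, (B, S))
    termination_symbol = 0
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.long)
    m_none = rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
    m_full = rnnt_loss_simple(lm, am, symbols, termination_symbol, boundary)
    assert torch.allclose(m_none, m_full, atol=1.0e-05)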
if __name__ == "__main__":
#torch.set_printoptions(edgeitems=30)
test_rnnt_logprobs_aux()
test_rnnt_logprobs_basic()