OpenDAS / FAST-RNNT / Commits / d53e923b

Commit d53e923b, authored Mar 08, 2022 by pkufool

    Move k2 rnnt_loss here

Parent: b5828e2b

Showing 17 changed files with 2538 additions and 1480 deletions (+2538 −1480)
fast_rnnt/python/csrc/mutual_information.cu            +67    -0
fast_rnnt/python/csrc/mutual_information.h             +28    -0
fast_rnnt/python/fast_rnnt/__init__.py                 +15    -0
fast_rnnt/python/fast_rnnt/mutual_information.py       +419   -0
fast_rnnt/python/fast_rnnt/rnnt_loss.py                +1162  -0
fast_rnnt/python/tests/CMakeLists.txt                  +26    -0
fast_rnnt/python/tests/mutual_information_test.py      +310   -0
fast_rnnt/python/tests/rnnt_loss_test.py               +511   -0
tests/requirements_test.txt                            +0     -2
tests/test.py                                          +0     -192
torch_mutual_information/__init__.py                   +0     -2
torch_mutual_information/mutual_information.py         +0     -311
torch_mutual_information/mutual_information_cpu.cpp    +0     -253
torch_mutual_information/mutual_information_cuda.cpp   +0     -77
torch_mutual_information/mutual_information_test.py    +0     -240
torch_mutual_information/rnnt.py                       +0     -304
torch_mutual_information/rnnt_test.py                  +0     -99
fast_rnnt/python/csrc/mutual_information.cu  0 → 100644
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/mutual_information.h"

PYBIND11_MODULE(_fast_rnnt, m) {
  m.doc() = "Python wrapper for Mutual Information.";

  m.def(
      "mutual_information_forward",
      [](torch::Tensor px, torch::Tensor py,
         torch::optional<torch::Tensor> boundary,
         torch::Tensor p) -> torch::Tensor {
        if (px.device().is_cpu()) {
          return fast_rnnt::MutualInformationCpu(px, py, boundary, p);
        } else {
#ifdef FT_WITH_CUDA
          return fast_rnnt::MutualInformationCuda(px, py, boundary, p);
#else
          // K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
          //               << "that you compiled the code with K2_WITH_CUDA.";
          return torch::Tensor();
#endif
        }
      },
      py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"));

  m.def(
      "mutual_information_backward",
      [](torch::Tensor px, torch::Tensor py,
         torch::optional<torch::Tensor> boundary, torch::Tensor p,
         torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
        if (px.device().is_cpu()) {
          return fast_rnnt::MutualInformationBackwardCpu(px, py, boundary, p,
                                                         ans_grad);
        } else {
#ifdef FT_WITH_CUDA
          return fast_rnnt::MutualInformationBackwardCuda(px, py, boundary, p,
                                                          ans_grad, true);
#else
          // K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
          //               << "that you compiled the code with K2_WITH_CUDA.";
          return std::vector<torch::Tensor>();
#endif
        }
      },
      py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
      py::arg("ans_grad"));
}
fast_rnnt/python/csrc/mutual_information.h  0 → 100644
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#include "pybind11/pybind11.h"
namespace py = pybind11;
#endif // FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
fast_rnnt/python/fast_rnnt/__init__.py  0 → 100644
from .mutual_information import mutual_information_recursion
from .mutual_information import joint_mutual_information_recursion
from .rnnt_loss import do_rnnt_pruning
from .rnnt_loss import get_rnnt_logprobs
from .rnnt_loss import get_rnnt_logprobs_joint
from .rnnt_loss import get_rnnt_logprobs_pruned
from .rnnt_loss import get_rnnt_logprobs_smoothed
from .rnnt_loss import get_rnnt_prune_ranges
from .rnnt_loss import rnnt_loss
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed
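A minimal usage sketch (not part of this commit), assuming the package is importable as `fast_rnnt` and the native `_fast_rnnt` extension has been built; shapes follow the docstrings in rnnt_loss.py further below, and the random inputs are placeholders:

# Illustrative only -- not part of this commit.
# lm: (B, S+1, C) language-model logits, am: (B, T, C) acoustic logits.
import torch
import fast_rnnt

B, S, T, C = 2, 5, 20, 30
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(0, C, (B, S))

# The simple RNN-T loss, where the "joiner" is just lm + am.
loss = fast_rnnt.rnnt_loss_simple(
    lm=lm, am=am, symbols=symbols, termination_symbol=0, reduction="mean"
)
print(loss)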
fast_rnnt/python/fast_rnnt/mutual_information.py  0 → 100644
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import _fast_rnnt
from torch import Tensor
from typing import Tuple, Optional, Sequence, Union, List


class MutualInformationRecursionFunction(torch.autograd.Function):
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired.
"""
    @staticmethod
    def forward(
        ctx,
        px: torch.Tensor,
        py: torch.Tensor,
        pxy_grads: List[Optional[torch.Tensor]],
        boundary: Optional[torch.Tensor] = None,
        return_grad: bool = False,
    ) -> torch.Tensor:
"""
Computing mutual information between two sequences of real vectors.
Args:
px:
A torch.Tensor of some floating point type, with shape
``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the
length of the ``x`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
length of the ``y`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols). In the mutual
information application, ``px[b][s][t]`` would represent the
following log odds ratio; ignoring the b index on the right
to make the notation more
compact::
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in
practice be of the form::
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where ``N`` is the number of terms that the sum over ``t'``
included, which might include some or all of the other sequences as
well as this one.
Note:
we don't require ``px`` and py to be contiguous, but the
code assumes for optimization purposes that the ``T`` axis has
stride 1.
py:
A torch.Tensor of the same dtype as ``px``, with shape
``[B][S+1][T]``, representing::
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py``
are contiguous.
pxy_grads:
A List to store the return grads of ``px`` and ``py``
if return_grad == True.
Remain unchanged if return_grad == False.
See `this PR <https://github.com/k2-fsa/k2/pull/924>` for more
information about why we add this parameter.
Note:
the length of the list must be 2, where the first element
represents the grads of ``px`` and the second one represents
the grads of ``py``.
boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``,
with ``0 <= s_begin <= s_end < S`` and ``0 <= t_begin <= t_end < T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the
``x`` and ``y`` sequences respectively, and can be used if not
all sequences are
of the same length.
return_grad:
Whether to return grads of ``px`` and ``py``, this grad standing
for the occupation probability is the output of the backward with a
``fake gradient`` the ``fake gradient`` is the same as the gradient
you'd get if you did
``torch.autograd.grad((scores.sum()), [px, py])``.
This is useful to implement the pruned version of rnnt loss.
Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of
the mutual information between the b'th pair of sequences. This is
defined by the following recursion on ``p[b,s,t]`` (where ``p``
is of shape ``[B,S+1,T+1]``), representing a mutual information
between sub-sequences of lengths ``s`` and ``t``::
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
where we handle edge cases by treating quantities with negative
indexes as **-infinity**. The extension to cases where the
boundaries are specified should be obvious; it just works on
shorter sequences with offsets into ``px`` and ``py``.
"""
        (B, S, T1) = px.shape
        T = py.shape[-1]
        assert T1 in [T, T + 1]
        assert py.shape == (B, S + 1, T)
        if boundary is not None:
            assert boundary.shape == (B, 4)

        # p is a tensor of shape (B, S + 1, T + 1), where p[s][t] is the
        # mutual information of the pair of subsequences of x and y that
        # are of length s and t respectively.  p[0][0] will be 0.0 and p[S][T]
        # is the mutual information of the entire pair of sequences,
        # i.e. of lengths S and T respectively.
        # It is computed as follows (in C++ and CUDA):
        #       p[b,0,0] = 0.0
        #       p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
        #                          p[b,s,t-1] + py[b,s,t-1])
        #               if s > 0 or t > 0,
        #               treating values with any -1 index as -infinity.
        # .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
        p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

        ans = _fast_rnnt.mutual_information_forward(px, py, boundary, p)

        px_grad, py_grad = None, None
        if return_grad or px.requires_grad or py.requires_grad:
            ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
            (px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
                px, py, boundary, p, ans_grad
            )
            ctx.save_for_backward(px_grad, py_grad)
        assert len(pxy_grads) == 2
        pxy_grads[0] = px_grad
        pxy_grads[1] = py_grad

        return ans
    @staticmethod
    def backward(
        ctx, ans_grad: Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
        (px_grad, py_grad) = ctx.saved_tensors
        (B,) = ans_grad.shape
        ans_grad = ans_grad.reshape(B, 1, 1)  # (B, 1, 1)
        px_grad *= ans_grad
        py_grad *= ans_grad
        return (px_grad, py_grad, None, None, None)
def mutual_information_recursion(
    px: Tensor,
    py: Tensor,
    boundary: Optional[Tensor] = None,
    return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired. The definitions of the arguments are definitions that
would be used when computing this type of mutual information, but you can
also view them as arbitrary quantities and just make use of the formula
computed by this function.
Args:
px:
A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``,
where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence
(including representations of ``EOS`` symbols but not ``BOS`` symbols),
and ``T`` is the length of the ``y`` sequence (including representations
of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information
application, ``px[b][s][t]`` would represent the following log odds
ratio; ignoring the b index on the right to make the notation more
compact::
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice
be of the form::
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where ``N`` is the number of terms that the sum over ``t'`` included,
which might include some or all of the other sequences as well as this
one.
Note:
we don't require ``px`` and py to be contiguous, but the
code assumes for optimization purposes that the ``T`` axis has
stride 1.
py:
A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``,
representing::
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are
contiguous.
boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``,
with ``0 <= s_begin <= s_end < S`` and ``0 <= t_begin <= t_end < T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the ``x`` and
``y`` sequences respectively, and can be used if not all sequences are
of the same length.
return_grad:
Whether to return grads of ``px`` and ``py``, this grad standing for the
occupation probability is the output of the backward with a
``fake gradient`` the ``fake gradient`` is the same as the gradient
you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``.
This is useful to implement the pruned version of rnnt loss.
Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual
information between the b'th pair of sequences. This is defined by
the following recursion on ``p[b,s,t]`` (where ``p`` is of shape
``[B,S+1,T+1]``), representing a mutual information between sub-sequences
of lengths ``s`` and ``t``::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
where we handle edge cases by treating quantities with negative indexes
as **-infinity**. The extension to cases where the boundaries are
specified should be obvious; it just works on shorter sequences with
offsets into ``px`` and ``py``.
"""
    assert px.ndim == 3
    B, S, T1 = px.shape
    T = py.shape[-1]
    assert px.shape[-1] in [T, T + 1]  # if T, then "modified".
    assert py.shape == (B, S + 1, T)
    assert px.dtype == py.dtype
    if boundary is not None:
        assert boundary.dtype == torch.int64
        assert boundary.shape == (B, 4)
        for s_begin, t_begin, s_end, t_end in boundary.tolist():
            assert 0 <= s_begin <= s_end <= S
            assert 0 <= t_begin <= t_end <= T
    # The following assertions are for efficiency
    assert px.is_contiguous()
    assert py.is_contiguous()

    pxy_grads = [None, None]
    scores = MutualInformationRecursionFunction.apply(
        px, py, pxy_grads, boundary, return_grad
    )
    px_grad, py_grad = pxy_grads
    return (scores, (px_grad, py_grad)) if return_grad else scores
def _inner_product(a: Tensor, b: Tensor) -> Tensor:
"""
Does inner product on the last dimension, with expected broadcasting,
i.e. equivalent to (a * b).sum(dim=-1)
without creating a large temporary.
"""
    assert a.shape[-1] == b.shape[-1]  # The last dim must be equal
    a = a.unsqueeze(-2)  # (..., 1, K)
    b = b.unsqueeze(-1)  # (..., K, 1)
    c = torch.matmul(a, b)  # (..., 1, 1)
    return c.squeeze(-1).squeeze(-1)
def joint_mutual_information_recursion(
    px: Sequence[Tensor],
    py: Sequence[Tensor],
    boundary: Optional[Tensor] = None,
) -> Sequence[Tensor]:
"""A recursion that is useful for modifications of RNN-T and similar loss
functions, where the recursion probabilities have a number of terms and you
want them reported separately. See mutual_information_recursion() for more
documentation of the basic aspects of this.
Args:
px:
a sequence of Tensors, each of the same shape [B][S][T+1]
py:
a sequence of Tensor, each of the same shape [B][S+1][T],
the sequence must be the same length as px.
boundary:
optionally, a LongTensor of shape [B][4] containing rows
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end < S
and 0 <= t_begin <= t_end < T, defaulting to [0, 0, S, T].
These are the beginning and one-past-the-last positions in the x
and y sequences respectively, and can be used if not all
sequences are of the same length.
Returns:
a Tensor of shape (len(px), B),
whose sum over dim 0 is the total log-prob of the recursion mentioned
below, per sequence. The first element of the sequence of length len(px)
is "special", in that it has an offset term reflecting the difference
between sum-of-log and log-of-sum; for more interpretable loss values,
the "main" part of your loss function should be first.
The recursion below applies if boundary == None, when it defaults
to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px
and py::
p = tensor of shape (B, S+1, T+1), containing -infinity
p[b,0,0] = 0.0
# do the following in loop over s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
p[b,s,t-1] + py_sum[b,s,t-1])
(if s > 0 or t > 0)
return b[:][S][T]
This function lets you implement the above recursion efficiently, except
that it gives you a breakdown of the contribution from all the elements of
px and py separately. As noted above, the first element of the
sequence is "special".
"""
    N = len(px)
    assert len(py) == N and N > 0
    B, S, T1 = px[0].shape
    T = py[0].shape[2]
    assert T1 in [T, T + 1]  # T if modified...
    assert py[0].shape == (B, S + 1, T)
    assert px[0].dtype == py[0].dtype

    px_cat = torch.stack(
        px, dim=0
    )  # (N, B, S, T+1) if !modified, (N, B, S, T) if modified.
    py_cat = torch.stack(py, dim=0)  # (N, B, S+1, T)
    px_tot = px_cat.sum(dim=0)  # (B, S, T+1)
    py_tot = py_cat.sum(dim=0)  # (B, S+1, T)

    if boundary is not None:
        assert boundary.dtype == torch.int64
        assert boundary.shape == (B, 4)
        for s_begin, t_begin, s_end, t_end in boundary.tolist():
            assert 0 <= s_begin <= s_end <= S
            assert 0 <= t_begin <= t_end <= T

    px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
    # The following assertions are for efficiency
    assert px_tot.ndim == 3
    assert py_tot.ndim == 3

    p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)

    # note, tot_probs is without grad.
    tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p)

    # this is a kind of "fake gradient" that we use, in effect to compute
    # occupation probabilities.  The backprop will work regardless of the
    # actual derivative w.r.t. the total probs.
    ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)

    (px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
        px_tot, py_tot, boundary, p, ans_grad
    )

    px_grad = px_grad.reshape(1, B, -1)
    py_grad = py_grad.reshape(1, B, -1)
    px_cat = px_cat.reshape(N, B, -1)
    py_cat = py_cat.reshape(N, B, -1)
    # get rid of -inf, would generate nan on product with 0
    px_cat = px_cat.clamp(min=torch.finfo(px_cat.dtype).min)
    py_cat = py_cat.clamp(min=torch.finfo(py_cat.dtype).min)

    x_prods = _inner_product(px_grad, px_cat)  # (N, B)
    y_prods = _inner_product(py_grad, py_cat)  # (N, B)

    # If all the occupation counts were exactly 1.0 (i.e. no partial counts),
    # "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
    # will be more positive due to the difference between log-of-sum and
    # sum-of-log
    prods = x_prods + y_prods  # (N, B)
    with torch.no_grad():
        offset = tot_probs - prods.sum(dim=0)  # (B,)
    prods[0] += offset
    return prods  # (N, B)
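A minimal sketch (again, not part of this commit) of calling mutual_information_recursion directly, assuming the extension is built; px is (B, S, T+1) and py is (B, S+1, T) as documented above, with random values standing in for real log-odds:

# Illustrative only -- not part of this commit.
import torch
from fast_rnnt import mutual_information_recursion

B, S, T = 2, 4, 10
px = torch.randn(B, S, T + 1)  # stand-in for log-odds of emitting x_s at (s, t)
py = torch.randn(B, S + 1, T)  # stand-in for log-odds of emitting y_t at (s, t)

scores = mutual_information_recursion(px, py)  # shape (B,)
scores, (px_grad, py_grad) = mutual_information_recursion(
    px, py, return_grad=True
)  # px_grad/py_grad are the occupation probabilities described above
print(scores.shape, px_grad.shape, py_grad.shape)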
fast_rnnt/python/fast_rnnt/rnnt_loss.py  0 → 100644
# Copyright 2021 Xiaomi Corp. (author: Daniel Povey, Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import k2
import torch
from torch import Tensor
from typing import Optional, Tuple, Union

from .mutual_information import mutual_information_recursion


def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
"""
Insert -inf's into `px` in appropriate places if `boundary` is not
None. If boundary == None and modified == False, px[:,:,-1] will
be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
to be -infinity.
Args:
px: a Tensor of shape [B][S][T+1] (this function is only
called if modified == False, see other docs for `modified`)
px is modified in-place and returned.
boundary: None, or a Tensor of shape [B][4] containing
[s_begin, t_begin, s_end, t_end]; we need only t_end.
"""
    if boundary is None:
        return px
    B, S, T1 = px.shape
    boundary = boundary[:, 3].reshape(B, 1, 1).expand(B, S, T1)
    return px.scatter_(dim=2, index=boundary, value=float("-inf"))
def get_rnnt_logprobs(
    lm: Tensor,
    am: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""
Reduces RNN-T problem (the simple case, where joiner network is just
addition), to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This function is called from rnnt_loss_simple(), but may be useful for
other purposes.
Args:
lm:
Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape::
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am:
Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape::
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
of length s and t respectively. px[b][s][t] represents the
probability of extending the subsequences of length (s,t) by one in
the s direction, given the particular symbol, and py[b][s][t]
represents the probability of extending the subsequences of length
(s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
    assert lm.ndim == 3
    assert am.ndim == 3
    assert lm.shape[0] == am.shape[0]
    assert lm.shape[2] == am.shape[2]
    (B, T, C) = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S)

    # subtracting am_max and lm_max is to ensure the probs are in a good range
    # to do exp() without causing underflow or overflow.
    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
    am_probs = (am - am_max).exp()
    lm_probs = (lm - lm_max).exp()
    # normalizers: [B][S+1][T]
    normalizers = (
        torch.matmul(lm_probs, am_probs.transpose(1, 2))
        + torch.finfo(am_probs.dtype).tiny
    ).log()

    # add lm_max and am_max to normalizers, to make it as if we had not
    # subtracted am_max and lm_max above.
    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]

    # px is the probs of the actual symbols..
    px_am = torch.gather(
        am.unsqueeze(1).expand(B, S, T, C),
        dim=3,
        index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
    ).squeeze(-1)  # [B][S][T]

    if not modified:
        px_am = torch.cat(
            (
                px_am,
                torch.full(
                    (B, S, 1),
                    float("-inf"),
                    device=px_am.device,
                    dtype=px_am.dtype,
                ),
            ),
            dim=2,
        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    px_lm = torch.gather(
        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
    )  # [B][S][1]

    px = px_am + px_lm  # [B][S][T+1], last slice with indexes out of
    # boundary is -inf
    px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1]

    # py is the probs of termination symbols, of shape [B][S+1][T]
    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
    py = py_am + py_lm - normalizers

    if not modified:
        px = fix_for_boundary(px, boundary)
    return (px, py)
def rnnt_loss_simple(
    lm: Tensor,
    am: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
    reduction: Optional[str] = "mean",
    return_grad: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
"""A simple case of the RNN-T loss, where the 'joiner' network is just
addition.
Args:
lm:
language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes
am:
acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols:
the symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
return_grad:
Whether to return grads of px and py, this grad standing for the
occupation probability is the output of the backward with a
`fake gradient`, the `fake gradient` is the same as the gradient you'd
get if you did `torch.autograd.grad((-loss.sum()), [px, py])`, note, the
loss here is the loss with reduction "none".
This is useful to implement the pruned version of rnnt loss.
Returns:
If return_grad is False, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch if reduction equals
to "none", otherwise a scalar with the reduction applied.
If return_grad is True, the grads of px and py, which is the output of
backward with a `fake gradient`(see above), will be returned too. And the
returned value will be a tuple like (loss, (px_grad, py_grad)).
"""
    px, py = get_rnnt_logprobs(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=termination_symbol,
        boundary=boundary,
        modified=modified,
    )
    scores_and_grads = mutual_information_recursion(
        px=px, py=py, boundary=boundary, return_grad=return_grad
    )
    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
    if reduction == "none":
        loss = -negated_loss
    elif reduction == "mean":
        loss = -torch.mean(negated_loss)
    elif reduction == "sum":
        loss = -torch.sum(negated_loss)
    else:
        assert (
            False
        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
    return (loss, scores_and_grads[1]) if return_grad else loss
def get_rnnt_logprobs_joint(
    logits: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Reduces RNN-T problem to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This function is called from rnnt_loss().
Args:
logits:
The output of joiner network, with shape (B, T, S + 1, C),
i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary)::
px: logprobs, of shape [B][S][T+1]
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
    assert logits.ndim == 4
    (B, T, S1, C) = logits.shape
    S = S1 - 1
    assert symbols.shape == (B, S)

    normalizers = torch.logsumexp(logits, dim=3)
    normalizers = normalizers.permute((0, 2, 1))

    px = torch.gather(
        logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
    ).squeeze(-1)
    px = px.permute((0, 2, 1))

    if not modified:
        px = torch.cat(
            (
                px,
                torch.full(
                    (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
                ),
            ),
            dim=2,
        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    px[:, :, :T] -= normalizers[:, :S, :]

    py = (
        logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
    )  # [B][S+1][T]
    py -= normalizers

    px = px.contiguous()
    py = py.contiguous()

    if not modified:
        px = fix_for_boundary(px, boundary)
    return (px, py)
def rnnt_loss(
    logits: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
    reduction: Optional[str] = "mean",
) -> Tensor:
"""A normal RNN-T loss, which uses a 'joiner' network output as input,
i.e. a 4-dimensional tensor.
Args:
logits:
The output of joiner network, with shape (B, T, S + 1, C),
i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
symbols:
The symbol sequences, a LongTensor of shape [B][S], and elements
in {0..C-1}.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
Returns:
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch, otherwise a scalar
with the reduction applied.
"""
    px, py = get_rnnt_logprobs_joint(
        logits=logits,
        symbols=symbols,
        termination_symbol=termination_symbol,
        boundary=boundary,
        modified=modified,
    )
    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
    if reduction == "none":
        return -negated_loss
    elif reduction == "mean":
        return -torch.mean(negated_loss)
    elif reduction == "sum":
        return -torch.sum(negated_loss)
    else:
        assert (
            False
        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
def _adjust_pruning_lower_bound(
    s_begin: torch.Tensor, s_range: int
) -> torch.Tensor:
"""Adjust s_begin (pruning lower bound) to make it satisfy the following
constraints
- monotonically increasing, i.e. s_begin[i] <= s_begin[i + 1]
- start with symbol 0 at the first frame.
- s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
any symbols.
To make it monotonically increasing, we can use the `monotonic_lower_bound`
function in k2, which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea
is: traverse the array in reverse order and update the elements by
`min_value = min(s_begin[i], min_value)`, with the initial `min_value` set to
`inf`.
The method we use to realize the `s_begin[i + 1] - s_begin[i] < s_range`
constraint is a little tricky. We first transform `s_begin` with
`s_begin = -(s_begin - (s_range - 1) * torch.arange(0,T))`
then we make the transformed `s_begin` monotonic increasing, after that,
we transform back `s_begin` with the same formula as the previous
transformation. The idea is: if we want to make
`s_begin[i + 1] - s_begin[i] < s_range` we only need to make
`-(s_begin[i] - i * (s_range - 1))` a non-decreasing array. Proof:
-(s_begin[i] - i * (s_range - 1)) <= -(s_begin[i + 1] - (i + 1) * (s_range - 1))
-s_begin[i] <= -s_begin[i + 1] + (i + 1) * (s_range - 1) - i * (s_range - 1)
-s_begin[i] <= -s_begin[i + 1] + s_range - 1
s_begin[i + 1] - s_begin[i] <= s_range - 1
s_begin[i + 1] - s_begin[i] < s_range
The above transformation cannot guarantee the start symbol to be 0, so we
have to set all the elements that are less than 0 to 0 before transforming
`s_begin` back.
"""
    # s_begin (B, T)
    (B, T) = s_begin.shape
    s_begin = k2.monotonic_lower_bound(s_begin)
    # do the magic transformation
    s_begin = -(
        s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
    )
    # make the transformed tensor to be non-decreasing
    s_begin = k2.monotonic_lower_bound(s_begin)
    # make start symbol to be zero.
    s_begin = torch.where(s_begin < 0, 0, s_begin)
    # do the magic transformation again to recover s_begin
    s_begin = -(
        s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
    )
    return s_begin
def get_rnnt_prune_ranges(
    px_grad: torch.Tensor,
    py_grad: torch.Tensor,
    boundary: torch.Tensor,
    s_range: int,
) -> torch.Tensor:
"""Get the pruning ranges of normal rnnt loss according to the grads
of px and py returned by mutual_information_recursion.
For each sequence with T frames, we will generate a tensor with the shape of
(T, s_range) containing the information about which symbols will be taken
into consideration for each frame. For example, here is a sequence with 10
frames whose corresponding symbols are `[A B C D E F]`; if s_range
equals 3, one possible ranges tensor will be::
[[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [1, 2, 3],
[1, 2, 3], [1, 2, 3], [3, 4, 5], [3, 4, 5], [3, 4, 5]]
which means we only consider `[A B C]` at frame 0, 1, 2, 3, and `[B C D]`
at frame 4, 5, 6, `[D E F]` at frame 7, 8, 9.
We can only consider a limited number of symbols because frames and symbols
are monotonically aligned; theoretically, a particular frame can only generate
a particular range of symbols.
Note:
For the generated tensor ranges, ranges[:, 0] is a monotonic increasing
tensor from 0 to `len(symbols)` and it satisfies
`ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
symbols.
Args:
px_grad:
The gradient of px, see docs in `mutual_information_recursion` for more
details of px.
py_grad:
The gradient of py, see docs in `mutual_information_recursion` for more
details of py.
boundary:
a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]
s_range:
How many symbols to keep for each frame.
Returns:
A tensor contains the kept symbols indexes for each frame, with shape
(B, T, s_range).
"""
    (B, S, T1) = px_grad.shape
    T = py_grad.shape[-1]
    assert T1 in [T, T + 1]
    assert py_grad.shape == (B, S + 1, T)
    assert boundary.shape == (B, 4)
    assert s_range >= 1
    if s_range > S:
        s_range = S

    px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device)
    py_pad = torch.zeros(
        (B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device
    )
    py_grad_padded = py_grad if T1 == T else torch.cat((py_grad, py_pad), dim=2)
    tot_grad = (
        torch.cat((px_grad, px_pad), dim=1) + py_grad_padded
    )  # (B, S + 1, T1)

    tot_grad = torch.cat(
        (
            torch.zeros((B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device),
            tot_grad,
        ),
        dim=1,
    )
    tot_grad = torch.cumsum(tot_grad, dim=1)
    diff_grad = tot_grad[:, s_range:, :] - tot_grad[:, 0:-s_range, :]
    s_begin = torch.argmax(diff_grad, dim=1)
    s_begin = s_begin[:, :T]

    # Handle the values of s_begin in padding positions.
    # -1 here means we fill the position of the last frame of real data with
    # padding value which is `len(symbols) - s_range + 1`.
    # This is to guarantee that we reach the last symbol at last frame of real
    # data.
    mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
    mask = mask < boundary[:, 3].reshape(B, 1) - 1

    s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
    # handle the cases when `len(symbols) < s_range`
    s_begin_padding = torch.where(s_begin_padding >= 0, s_begin_padding, 0)

    s_begin = torch.where(mask, s_begin, s_begin_padding)

    # adjusting lower bound to make it satisfy some constraints, see docs in
    # `_adjust_pruning_lower_bound` for more details of these constraints.
    # T1 == T here means we are using the modified version of the transducer;
    # the third constraint becomes `s_begin[i + 1] - s_begin[i] < 2`, because
    # it only emits one symbol per frame.
    s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)

    ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
        s_range, device=px_grad.device
    )
    return ranges
def do_rnnt_pruning(
    am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Prune the encoder (am) output and the prediction network (lm)
output of RNN-T.
Args:
am:
The encoder output, with shape (B, T, C)
lm:
The prediction network output, with shape (B, S + 1, C)
ranges:
A tensor containing the symbol indexes for each frame that we want to
keep. Its shape is (B, T, s_range), see the docs in
`get_rnnt_prune_ranges` for more details of this tensor.
Returns:
Return the pruned am and lm with shape (B, T, s_range, C)
"""
    # am (B, T, C)
    # lm (B, S + 1, C)
    # ranges (B, T, s_range)
    assert ranges.shape[0] == am.shape[0]
    assert ranges.shape[0] == lm.shape[0]
    assert am.shape[1] == ranges.shape[1]
    (B, T, s_range) = ranges.shape
    (B, S1, C) = lm.shape
    S = S1 - 1

    # (B, T, s_range, C)
    am_pruning = am.unsqueeze(2).expand((B, T, s_range, C))

    # (B, T, s_range, C)
    lm_pruning = torch.gather(
        lm.unsqueeze(1).expand((B, T, S + 1, C)),
        dim=2,
        index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C)),
    )
    return am_pruning, lm_pruning
def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
"""Roll tensor with different shifts for each row.
Note:
We assume the src is a 3 dimensions tensor and roll the last dimension.
Example:
>>> src = torch.arange(15).reshape((1,3,5))
>>> src
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]]])
>>> shift = torch.tensor([[1, 2, 3]])
>>> shift
tensor([[1, 2, 3]])
>>> _roll_by_shifts(src, shift)
tensor([[[ 4, 0, 1, 2, 3],
[ 8, 9, 5, 6, 7],
[12, 13, 14, 10, 11]]])
"""
    assert src.dim() == 3
    (B, T, S) = src.shape
    assert shifts.shape == (B, T)

    index = (
        torch.arange(S, device=src.device)
        .view((1, S))
        .repeat((T, 1))
        .repeat((B, 1, 1))
    )
    index = (index - shifts.reshape(B, T, 1)) % S
    return torch.gather(src, 2, index)
def get_rnnt_logprobs_pruned(
    logits: Tensor,
    symbols: Tensor,
    ranges: Tensor,
    termination_symbol: int,
    boundary: Tensor,
    modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Construct px, py for mutual_information_recursion with pruned output.
Args:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C)
symbols:
The symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
Return the px (B, S, T) if modified else (B, S, T + 1) and
py (B, S + 1, T) needed by mutual_information_recursion.
"""
    # logits (B, T, s_range, C)
    # symbols (B, S)
    # ranges (B, T, s_range)
    assert logits.ndim == 4
    (B, T, s_range, C) = logits.shape
    assert ranges.shape == (B, T, s_range)
    (B, S) = symbols.shape

    normalizers = torch.logsumexp(logits, dim=3)

    symbols_with_terminal = torch.cat(
        (
            symbols,
            torch.tensor(
                [termination_symbol] * B,
                dtype=torch.int64,
                device=symbols.device,
            ).reshape((B, 1)),
        ),
        dim=1,
    )

    # (B, T, s_range)
    pruned_symbols = torch.gather(
        symbols_with_terminal.unsqueeze(1).expand((B, T, S + 1)),
        dim=2,
        index=ranges,
    )

    # (B, T, s_range)
    px = torch.gather(
        logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1)
    ).squeeze(-1)
    px = px - normalizers

    # (B, T, S) with index larger than s_range in dim 2 fill with -inf
    px = torch.cat(
        (
            px,
            torch.full(
                (B, T, S + 1 - s_range),
                float("-inf"),
                device=px.device,
                dtype=px.dtype,
            ),
        ),
        dim=2,
    )

    # (B, T, S) with index out of s_range in dim 2 fill with -inf
    px = _roll_by_shifts(px, ranges[:, :, 0])[:, :, :S]
    px = px.permute((0, 2, 1))

    if not modified:
        px = torch.cat(
            (
                px,
                torch.full(
                    (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
                ),
            ),
            dim=2,
        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    py = logits[:, :, :, termination_symbol].clone()  # (B, T, s_range)
    py = py - normalizers

    # (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf
    py = torch.cat(
        (
            py,
            torch.full(
                (B, T, S + 1 - s_range),
                float("-inf"),
                device=py.device,
                dtype=py.dtype,
            ),
        ),
        dim=2,
    )

    # (B, T, S + 1) with index out of s_range in dim 2 fill with -inf
    py = _roll_by_shifts(py, ranges[:, :, 0])
    # (B, S + 1, T)
    py = py.permute((0, 2, 1))

    px = px.contiguous()
    py = py.contiguous()

    if not modified:
        px = fix_for_boundary(px, boundary)
    return (px, py)
def rnnt_loss_pruned(
    logits: Tensor,
    symbols: Tensor,
    ranges: Tensor,
    termination_symbol: int,
    boundary: Tensor = None,
    modified: bool = False,
    reduction: Optional[str] = "mean",
) -> Tensor:
"""An RNN-T loss with pruning, which uses a pruned 'joiner' network output
as input, i.e. a 4-dimensional tensor with shape (B, T, s_range, C), where
s_range is the number of symbols kept for each frame.
Args:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C),
i.e. batch, time_seq_len, prune_range, num_classes
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
Returns:
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch, otherwise a scalar
with the reduction applied.
"""
    px, py = get_rnnt_logprobs_pruned(
        logits=logits,
        symbols=symbols,
        ranges=ranges,
        termination_symbol=termination_symbol,
        boundary=boundary,
        modified=modified,
    )
    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
    if reduction == "none":
        return -negated_loss
    elif reduction == "mean":
        return -torch.mean(negated_loss)
    elif reduction == "sum":
        return -torch.sum(negated_loss)
    else:
        assert (
            False
        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
def get_rnnt_logprobs_smoothed(
    lm: Tensor,
    am: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    lm_only_scale: float = 0.1,
    am_only_scale: float = 0.1,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""
Reduces RNN-T problem (the simple case, where joiner network is just
addition), to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This version allows you to make the loss-function one of the form::
lm_only_scale * lm_probs +
am_only_scale * am_probs +
(1-lm_only_scale-am_only_scale) * combined_probs
where lm_probs and am_probs are the probabilities given the lm and acoustic
model independently.
This function is called from
:func:`rnnt_loss_smoothed`, but may be useful for other purposes.
Args:
lm:
Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape::
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am:
Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape::
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
lm_only_scale:
the scale on the "LM-only" part of the loss.
am_only_scale:
the scale on the "AM-only" part of the loss, for which we use
an "averaged" LM (averaged over all histories, so effectively unigram).
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
of length s and t respectively. px[b][s][t] represents the
probability of extending the subsequences of length (s,t) by one in
the s direction, given the particular symbol, and py[b][s][t]
represents the probability of extending the subsequences of length
(s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
we cannot emit any symbols. This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
    assert lm.ndim == 3
    assert am.ndim == 3
    assert lm.shape[0] == am.shape[0]
    assert lm.shape[2] == am.shape[2]
    (B, T, C) = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S)

    # Caution: some parts of this code are a little less clear than they could
    # be due to optimizations.  In particular it may not be totally obvious that
    # all of the logprobs here are properly normalized.  We test that
    # this code is invariant to adding constants in the appropriate ways.

    # subtracting am_max and lm_max is to ensure the probs are in a good range
    # to do exp() without causing underflow or overflow.
    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
    am_probs = (am - am_max).exp()  # [B][T][C]
    lm_probs = (lm - lm_max).exp()  # [B][S+1][C]
    # normalizers: [B][S+1][T]
    normalizers = (
        torch.matmul(lm_probs, am_probs.transpose(1, 2))
        + torch.finfo(lm_probs.dtype).tiny
    ).log()

    # normalizer per frame, if we take only the LM probs by themselves
    lmonly_normalizers = lm_probs.sum(
        dim=2, keepdim=True
    )  # lmonly_normalizers: [B][S+1][1]
    unigram_lm = (
        torch.mean(lm_probs / lmonly_normalizers, dim=(0, 1), keepdim=True)
        + torch.finfo(lm_probs.dtype).tiny
    )  # [1][1][C]
    amonly_normalizers = (
        torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C))
        .reshape(B, T, 1)
        .log()
        + am_max
    )  # [B][T][1]
    amonly_normalizers = amonly_normalizers.transpose(1, 2)  # [B][1][T]
    unigram_lm = unigram_lm.log()
    lmonly_normalizers = (
        lmonly_normalizers.log() + lm_max
    )  # [B][S+1][1], log-normalizer, used for LM-only part of prob.

    # add lm_max and am_max to normalizers, to make it as if we had not
    # subtracted am_max and lm_max above.
    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]

    # px is the probs of the actual symbols (not yet normalized)..
    px_am = torch.gather(
        am.unsqueeze(1).expand(B, S, T, C),
        dim=3,
        index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
    ).squeeze(-1)  # [B][S][T]

    if not modified:
        px_am = torch.cat(
            (
                px_am,
                torch.full(
                    (B, S, 1),
                    float("-inf"),
                    device=px_am.device,
                    dtype=px_am.dtype,
                ),
            ),
            dim=2,
        )  # now: [B][S][T+1], index [:,:,T] has -inf..

    px_lm = torch.gather(
        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
    )  # [B][S][1]
    px_lm_unigram = torch.gather(
        unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)
    )  # [B][S][1]

    px = px_am + px_lm  # [B][S][T+1] if not modified, [B][S][T] if modified
    px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1] or [B][S][T]

    px_amonly = (
        px_am + px_lm_unigram
    )  # [B][S][T+1] if !modified; [B][S][T] if modified.
    px_amonly[:, :, :T] -= amonly_normalizers
    px_lmonly = px_lm - lmonly_normalizers[:, :S, :]

    # py is the probs of termination symbols, of shape [B][S+1][T]
    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
    py = py_am + py_lm - normalizers

    py_lm_unigram = unigram_lm[0][0][termination_symbol]  # scalar, normalized..
    py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][S+1][T]
    py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][T]

    combined_scale = 1.0 - lm_only_scale - am_only_scale

    # We need to avoid exact zeros in the scales because otherwise multiplying
    # -inf by zero generates nan.
    if lm_only_scale == 0.0:
        lm_only_scale = 1.0e-20
    if am_only_scale == 0.0:
        am_only_scale = 1.0e-20

    px_interp = (
        px * combined_scale
        + px_lmonly * lm_only_scale
        + px_amonly * am_only_scale
    )
    py_interp = (
        py * combined_scale
        + py_lmonly * lm_only_scale
        + py_amonly * am_only_scale
    )

    if not modified:
        px_interp = fix_for_boundary(px_interp, boundary)
    return (px_interp, py_interp)
def rnnt_loss_smoothed(
    lm: Tensor,
    am: Tensor,
    symbols: Tensor,
    termination_symbol: int,
    lm_only_scale: float = 0.1,
    am_only_scale: float = 0.1,
    boundary: Optional[Tensor] = None,
    modified: bool = False,
    reduction: Optional[str] = "mean",
    return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
    """A simple case of the RNN-T loss, where the 'joiner' network is just
    addition.

    Args:
      lm:
        Language-model part of unnormalized log-probs of symbols, with shape
        (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
        These are assumed to be well-normalized, in the sense that we could
        use them as probabilities separately from the am scores.
      am:
        Acoustic-model part of unnormalized log-probs of symbols, with shape
        (B, T, C), i.e. batch, frame, num_classes.
      symbols:
        The symbol sequences, a LongTensor of shape [B][S], with elements in
        {0..C-1}.
      termination_symbol:
        The termination symbol, with 0 <= termination_symbol < C.
      lm_only_scale:
        The scale on the "LM-only" part of the loss.
      am_only_scale:
        The scale on the "AM-only" part of the loss, for which we use
        an "averaged" LM (averaged over all histories, so effectively unigram).
      boundary:
        A LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
        [0, 0, S, T] if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
      modified:
        If True, each time a real symbol is consumed a frame will also be
        consumed, so at most 1 symbol can appear per frame.
      reduction:
        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
        `none`: no reduction will be applied.
        `mean`: apply `torch.mean` over the batches.
        `sum`: the output will be summed.
        Default: `mean`.
      return_grad:
        Whether to return the grads of px and py.  This grad, which stands for
        the occupation probability, is the output of the backward pass with a
        `fake gradient`; the `fake gradient` is the same as the gradient you'd
        get if you did `torch.autograd.grad((-loss.sum()), [px, py])`.  Note:
        the loss here is the loss with reduction "none".
        This is useful to implement the pruned version of the rnnt loss.

    Returns:
      If return_grad is False, returns a tensor of shape (B,) containing the
      total RNN-T loss values for each element of the batch if reduction equals
      "none", otherwise a scalar with the reduction applied.
      If return_grad is True, the grads of px and py, which are the output of
      the backward pass with a `fake gradient` (see above), are returned too,
      and the returned value is a tuple like (loss, (px_grad, py_grad)).
    """
    px, py = get_rnnt_logprobs_smoothed(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=termination_symbol,
        lm_only_scale=lm_only_scale,
        am_only_scale=am_only_scale,
        boundary=boundary,
        modified=modified,
    )
    scores_and_grads = mutual_information_recursion(
        px=px, py=py, boundary=boundary, return_grad=return_grad
    )
    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
    if reduction == "none":
        loss = -negated_loss
    elif reduction == "mean":
        loss = -torch.mean(negated_loss)
    elif reduction == "sum":
        loss = -torch.sum(negated_loss)
    else:
        assert (
            False
        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
    return (loss, scores_and_grads[1]) if return_grad else loss
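For reference, a minimal usage sketch of rnnt_loss_smoothed (not part of the commit; shapes and values are made up, and it assumes the fast_rnnt package has been built and installed so that it is importable, as in the tests below):

# Sketch only: random lm/am scores for a batch of 2 sequences, 10 frames,
# 5 symbols per sequence, and a vocabulary of 8 classes.
import torch
import fast_rnnt

B, T, S, C = 2, 10, 5, 8
lm = torch.randn(B, S + 1, C)              # language-model scores, [B][S+1][C]
am = torch.randn(B, T, C)                  # acoustic-model scores, [B][T][C]
symbols = torch.randint(0, C - 1, (B, S))  # reference symbol sequences
loss = fast_rnnt.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=C - 1,              # id used as the termination/blank symbol
    lm_only_scale=0.1,
    am_only_scale=0.1,
    reduction="mean",
)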
fast_rnnt/python/tests/CMakeLists.txt  (new file, mode 100644)
function(fast_rnnt_add_py_test source)
  get_filename_component(name ${source} NAME_WE)
  set(name "${name}_py")

  add_test(NAME ${name}
    COMMAND
      "${PYTHON_EXECUTABLE}"
      "${CMAKE_CURRENT_SOURCE_DIR}/${source}"
  )

  get_filename_component(fast_rnnt_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)

  set_property(TEST ${name}
    PROPERTY ENVIRONMENT
    "PYTHONPATH=${fast_rnnt_path}:$<TARGET_FILE_DIR:_fast_rnnt>:$ENV{PYTHONPATH}"
  )
endfunction()

# please sort the files in alphabetic order
set(py_test_files
  mutual_information_test.py
  rnnt_loss_test.py
)

foreach(source IN LISTS py_test_files)
  fast_rnnt_add_py_test(${source})
endforeach()
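Each registered test name gets a `_py` suffix, and its PYTHONPATH is set so that both the pure-Python `fast_rnnt` package and the compiled `_fast_rnnt` extension are importable. After building, a single test can therefore be run from the CMake build directory with, for example, `ctest --verbose -R mutual_information_test_py` (this is the same command quoted in the header comment of each test file below).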
fast_rnnt/python/tests/mutual_information_test.py  (new file, mode 100644)
#!/usr/bin/env python3
#
# Copyright      2021  Xiaomi Corporation (authors: Daniel Povey,
#                                                   Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this single test, use
#
#  ctest --verbose -R mutual_information_test_py

import random
import unittest

import fast_rnnt
import torch


# Caution: this will fail occasionally due to cutoffs not being quite large
# enough. As long as it passes most of the time, it's OK.
class TestMutualInformation(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.devices = [torch.device("cpu")]
        if torch.cuda.is_available():
            cls.devices.append(torch.device("cuda", 0))
            if torch.cuda.device_count() > 1:
                torch.cuda.set_device(1)
                cls.devices.append(torch.device("cuda", 1))
        cls.dtypes = [torch.float32, torch.float64]

    def test_mutual_information_basic(self):
        for _iter in range(100):
            (B, S, T) = (
                random.randint(1, 10),
                random.randint(1, 16),
                random.randint(1, 500),
            )
            random_px = random.random() < 0.2
            random_py = random.random() < 0.2
            random_boundary = random.random() < 0.7
            big_px = random.random() < 0.2
            big_py = random.random() < 0.2
            modified = random.random() < 0.5
            if modified and T < S:
                T = S + random.randint(0, 30)

            for dtype in self.dtypes:
                for device in self.devices:
                    if random_boundary:

                        def get_boundary_row():
                            this_S = random.randint(0, S)  # allow empty sequence
                            this_T = random.randint(
                                this_S if modified else 1, T
                            )
                            s_begin = random.randint(0, S - this_S)
                            t_begin = random.randint(0, T - this_T)
                            s_end = s_begin + this_S
                            t_end = t_begin + this_T
                            return [s_begin, t_begin, s_end, t_end]

                        if device == torch.device("cpu"):
                            boundary = torch.tensor(
                                [get_boundary_row() for _ in range(B)],
                                dtype=torch.int64,
                                device=device,
                            )
                        else:
                            boundary = boundary.to(device)
                    else:
                        # Use default boundary, but either specified directly
                        # or not.
                        if random.random() < 0.5:
                            boundary = (
                                torch.tensor([0, 0, S, T], dtype=torch.int64)
                                .unsqueeze(0)
                                .expand(B, 4)
                                .to(device)
                            )
                        else:
                            boundary = None

                    if device == torch.device("cpu"):
                        if random_px:
                            # log of an odds ratio
                            px = torch.randn(
                                B, S, T + (0 if modified else 1), dtype=dtype
                            ).to(device)
                            if S > 1 and not random_boundary and not modified:
                                px[:, :, -1:] = float("-inf")
                        else:
                            # log of an odds ratio
                            px = torch.zeros(
                                B, S, T + (0 if modified else 1), dtype=dtype
                            ).to(device)
                        # px and py get exponentiated, and then multiplied
                        # together up to 32 times (BLOCK_SIZE in the CUDA code),
                        # so 15 is actually a big number that could lead to
                        # overflow.
                        if big_px:
                            px += 15.0
                        if random_py:
                            # log of an odds ratio
                            py = torch.randn(B, S + 1, T, dtype=dtype).to(device)
                        else:
                            # log of an odds ratio
                            py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)
                        if big_py:
                            py += 15.0
                    else:
                        px = px.to(device).detach()
                        py = py.to(device).detach()
                    px.requires_grad = True
                    py.requires_grad = True

                    m = fast_rnnt.mutual_information_recursion(px, py, boundary)
                    m2 = fast_rnnt.joint_mutual_information_recursion(
                        (px,), (py,), boundary
                    )
                    m3 = fast_rnnt.joint_mutual_information_recursion(
                        (px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary
                    )
                    # it is supposed to be identical only after
                    # summing over dim 0, corresponding to the
                    # sequence dim
                    m3 = m3.sum(dim=0)

                    assert torch.allclose(m, m2)
                    assert torch.allclose(m, m3)

                    # the loop this is in checks that the CPU and CUDA versions
                    # give the same derivative;
                    # by randomizing which of m, m2 or m3 we backprop, we also
                    # ensure that the joint version of the code gives the same
                    # derivative as the regular version
                    scale = 3
                    if random.random() < 0.5:
                        (m.sum() * scale).backward()
                    elif random.random() < 0.5:
                        (m2.sum() * scale).backward()
                    else:
                        (m3.sum() * scale).backward()

                    if device == torch.device("cpu"):
                        expected_px_grad = px.grad
                        expected_py_grad = py.grad
                        expected_m = m
                    assert torch.allclose(
                        px.grad,
                        expected_px_grad.to(device),
                        atol=1.0e-02,
                        rtol=1.0e-02,
                    )
                    assert torch.allclose(
                        py.grad,
                        expected_py_grad.to(device),
                        atol=1.0e-02,
                        rtol=1.0e-02,
                    )
                    assert torch.allclose(
                        m, expected_m.to(device), atol=1.0e-02, rtol=1.0e-02
                    )

    def test_mutual_information_deriv(self):
        for _iter in range(100):
            (B, S, T) = (
                random.randint(1, 100),
                random.randint(1, 200),
                random.randint(1, 200),
            )
            random_px = random.random() < 0.2
            random_py = random.random() < 0.2
            random_boundary = random.random() < 0.7
            big_px = random.random() < 0.2
            big_py = random.random() < 0.2
            modified = random.random() < 0.5
            if modified and T < S:
                T = S + random.randint(0, 30)

            for dtype in self.dtypes:
                for device in self.devices:
                    if random_boundary:

                        def get_boundary_row():
                            this_S = random.randint(1, S)
                            this_T = random.randint(
                                this_S if modified else 1, T
                            )
                            s_begin = random.randint(0, S - this_S)
                            t_begin = random.randint(0, T - this_T)
                            s_end = s_begin + this_S
                            t_end = t_begin + this_T
                            return [s_begin, t_begin, s_end, t_end]

                        if device == torch.device("cpu"):
                            boundary = torch.tensor(
                                [get_boundary_row() for _ in range(B)],
                                dtype=torch.int64,
                                device=device,
                            )
                        else:
                            boundary = boundary.to(device)
                    else:
                        # Use default boundary, but either specified directly
                        # or not.
                        if random.random() < 0.5:
                            boundary = (
                                torch.tensor([0, 0, S, T], dtype=torch.int64)
                                .unsqueeze(0)
                                .expand(B, 4)
                                .to(device)
                            )
                        else:
                            boundary = None

                    T1 = T + (0 if modified else 1)
                    if device == torch.device("cpu"):
                        if random_px:
                            # log of an odds ratio
                            px = torch.randn(B, S, T1, dtype=dtype).to(device)
                        else:
                            # log of an odds ratio
                            px = torch.zeros(B, S, T1, dtype=dtype).to(device)
                        # px and py get exponentiated, and then multiplied
                        # together up to 32 times (BLOCK_SIZE in the CUDA code),
                        # so 15 is actually a big number that could lead to
                        # overflow.
                        if big_px:
                            px += 15.0
                        if random_py:
                            # log of an odds ratio
                            py = torch.randn(B, S + 1, T, dtype=dtype).to(device)
                        else:
                            # log of an odds ratio
                            py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)
                        if big_py:
                            py += 15.0
                    else:
                        px = px.to(device).detach()
                        py = py.to(device).detach()
                    px.requires_grad = True
                    py.requires_grad = True

                    m = fast_rnnt.mutual_information_recursion(px, py, boundary)

                    m_grad = torch.randn(B, dtype=dtype, device=device)
                    m.backward(gradient=m_grad)
                    delta = 1.0e-04
                    delta_px = delta * torch.randn_like(px)
                    m2 = fast_rnnt.mutual_information_recursion(
                        px + delta_px, py, boundary
                    )
                    delta_m = m2 - m
                    observed_delta = (delta_m * m_grad).sum().to("cpu")
                    predicted_delta = (delta_px * px.grad).sum().to("cpu")

                    atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
                    rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04

                    assert torch.allclose(
                        observed_delta, predicted_delta, atol=atol, rtol=rtol
                    )

                    delta_py = delta * torch.randn_like(py)
                    m2 = fast_rnnt.mutual_information_recursion(
                        px, py + delta_py, boundary
                    )
                    delta_m = m2 - m
                    observed_delta = (delta_m * m_grad).sum().to("cpu")
                    predicted_delta = (delta_py * py.grad).sum().to("cpu")

                    assert torch.allclose(
                        observed_delta, predicted_delta, atol=atol, rtol=rtol
                    )


if __name__ == "__main__":
    unittest.main()
fast_rnnt/python/tests/rnnt_loss_test.py  (new file, mode 100644)
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (authors: Daniel Povey,
# Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this single test, use
#
# ctest --verbose -R rnnt_loss_test_py
import
unittest
import
fast_rnnt
import
random
import
torch
class
TestRnntLoss
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
devices
=
[
torch
.
device
(
"cpu"
)]
if
torch
.
cuda
.
is_available
():
cls
.
devices
.
append
(
torch
.
device
(
"cuda"
,
0
))
if
torch
.
cuda
.
device_count
()
>
1
:
torch
.
cuda
.
set_device
(
1
)
cls
.
devices
.
append
(
torch
.
device
(
"cuda"
,
1
))
try
:
import
torchaudio
import
torchaudio.functional
if
hasattr
(
torchaudio
.
functional
,
"rnnt_loss"
):
cls
.
has_torch_rnnt_loss
=
True
else
:
cls
.
has_torch_rnnt_loss
=
False
print
(
f
"Current torchaudio version:
{
torchaudio
.
__version__
}
\n
"
"Skipping the tests of comparing rnnt loss with torch "
"one, to enable these tests please install a "
"version >= 0.10.0"
)
except
ImportError
as
e
:
cls
.
has_torch_rnnt_loss
=
False
print
(
f
"Import torchaudio error, error message:
{
e
}
\n
"
"Skipping the tests of comparing rnnt loss with torch "
"one, to enable these tests, please install torchaudio "
"with version >= 0.10.0"
)
def
test_rnnt_loss_basic
(
self
):
B
=
1
S
=
3
T
=
4
# C = 3
for
device
in
self
.
devices
:
# lm: [B][S+1][C]
lm
=
torch
.
tensor
(
[[[
0
,
0
,
1
],
[
0
,
1
,
1
],
[
1
,
0
,
1
],
[
2
,
2
,
0
]]],
dtype
=
torch
.
float
,
device
=
device
,
)
# am: [B][T][C]
am
=
torch
.
tensor
(
[[[
0
,
1
,
2
],
[
0
,
0
,
0
],
[
0
,
2
,
4
],
[
0
,
3
,
3
]]],
dtype
=
torch
.
float
,
device
=
device
,
)
termination_symbol
=
2
symbols
=
torch
.
tensor
([[
0
,
1
,
0
]],
dtype
=
torch
.
long
,
device
=
device
)
px
,
py
=
fast_rnnt
.
get_rnnt_logprobs
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
)
assert
px
.
shape
==
(
B
,
S
,
T
+
1
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
symbols
.
shape
==
(
B
,
S
)
m
=
fast_rnnt
.
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
None
)
if
device
==
torch
.
device
(
"cpu"
):
expected
=
-
m
assert
torch
.
allclose
(
-
m
,
expected
.
to
(
device
))
# test rnnt_loss_simple
m
=
fast_rnnt
.
rnnt_loss_simple
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
None
,
reduction
=
"none"
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
# test rnnt_loss_smoothed
m
=
fast_rnnt
.
rnnt_loss_smoothed
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.0
,
boundary
=
None
,
reduction
=
"none"
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
probs
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
# test rnnt_loss
m
=
fast_rnnt
.
rnnt_loss
(
logits
=
probs
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
None
,
reduction
=
"none"
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
# compare with torchaudio rnnt_loss
if
self
.
has_torch_rnnt_loss
:
import
torchaudio.functional
m
=
torchaudio
.
functional
.
rnnt_loss
(
logits
=
probs
,
targets
=
symbols
.
int
(),
logit_lengths
=
torch
.
tensor
(
[
T
]
*
B
,
dtype
=
torch
.
int32
,
device
=
device
),
target_lengths
=
torch
.
tensor
(
[
S
]
*
B
,
dtype
=
torch
.
int32
,
device
=
device
),
blank
=
termination_symbol
,
reduction
=
"none"
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
# should be invariant to adding a constant for any frame.
lm
+=
torch
.
randn
(
B
,
S
+
1
,
1
,
device
=
device
)
am
+=
torch
.
randn
(
B
,
T
,
1
,
device
=
device
)
m
=
fast_rnnt
.
rnnt_loss_simple
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
None
,
reduction
=
"none"
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
m
=
fast_rnnt
.
rnnt_loss_smoothed
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.0
,
boundary
=
None
,
reduction
=
"none"
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
probs
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
m
=
fast_rnnt
.
rnnt_loss
(
logits
=
probs
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
None
,
reduction
=
"none"
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
def
test_rnnt_loss_random
(
self
):
B
=
5
S
=
20
T
=
300
C
=
100
frames
=
torch
.
randint
(
S
,
T
,
(
B
,))
seq_length
=
torch
.
randint
(
3
,
S
-
1
,
(
B
,))
T
=
torch
.
max
(
frames
)
S
=
torch
.
max
(
seq_length
)
am_
=
torch
.
randn
((
B
,
T
,
C
),
dtype
=
torch
.
float32
)
lm_
=
torch
.
randn
((
B
,
S
+
1
,
C
),
dtype
=
torch
.
float32
)
symbols_
=
torch
.
randint
(
0
,
C
-
1
,
(
B
,
S
))
termination_symbol
=
C
-
1
boundary_
=
torch
.
zeros
((
B
,
4
),
dtype
=
torch
.
int64
)
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
3
]
=
frames
for
modified
in
[
True
,
False
]:
for
device
in
self
.
devices
:
# lm: [B][S+1][C]
lm
=
lm_
.
to
(
device
)
# am: [B][T][C]
am
=
am_
.
to
(
device
)
symbols
=
symbols_
.
to
(
device
)
boundary
=
boundary_
.
to
(
device
)
px
,
py
=
fast_rnnt
.
get_rnnt_logprobs
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
modified
=
modified
,
)
assert
px
.
shape
==
(
B
,
S
,
T
)
if
modified
else
(
B
,
S
,
T
+
1
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
symbols
.
shape
==
(
B
,
S
)
m
=
fast_rnnt
.
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
boundary
)
if
device
==
torch
.
device
(
"cpu"
):
expected
=
-
torch
.
mean
(
m
)
assert
torch
.
allclose
(
-
torch
.
mean
(
m
),
expected
.
to
(
device
))
m
=
fast_rnnt
.
rnnt_loss_simple
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
modified
=
modified
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
m
=
fast_rnnt
.
rnnt_loss_smoothed
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.0
,
boundary
=
boundary
,
modified
=
modified
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
probs
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
m
=
fast_rnnt
.
rnnt_loss
(
logits
=
probs
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
modified
=
modified
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
# compare with torchaudio rnnt_loss
if
self
.
has_torch_rnnt_loss
and
not
modified
:
import
torchaudio.functional
m
=
torchaudio
.
functional
.
rnnt_loss
(
logits
=
probs
,
targets
=
symbols
.
int
(),
logit_lengths
=
boundary
[:,
3
].
int
(),
target_lengths
=
boundary
[:,
2
].
int
(),
blank
=
termination_symbol
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
# should be invariant to adding a constant for any frame.
lm
+=
torch
.
randn
(
B
,
S
+
1
,
1
,
device
=
device
)
am
+=
torch
.
randn
(
B
,
T
,
1
,
device
=
device
)
m
=
fast_rnnt
.
rnnt_loss_simple
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
modified
=
modified
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
probs
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
m
=
fast_rnnt
.
rnnt_loss
(
logits
=
probs
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
modified
=
modified
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
m
=
fast_rnnt
.
rnnt_loss_smoothed
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.0
,
boundary
=
boundary
,
modified
=
modified
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
def
test_rnnt_loss_gradient
(
self
):
if
self
.
has_torch_rnnt_loss
:
import
torchaudio.functional
B
=
5
S
=
20
T
=
300
C
=
100
frames
=
torch
.
randint
(
S
,
T
,
(
B
,))
seq_length
=
torch
.
randint
(
3
,
S
-
1
,
(
B
,))
T
=
torch
.
max
(
frames
)
S
=
torch
.
max
(
seq_length
)
am_
=
torch
.
randn
((
B
,
T
,
C
),
dtype
=
torch
.
float32
)
lm_
=
torch
.
randn
((
B
,
S
+
1
,
C
),
dtype
=
torch
.
float32
)
symbols_
=
torch
.
randint
(
0
,
C
-
1
,
(
B
,
S
))
termination_symbol
=
C
-
1
boundary_
=
torch
.
zeros
((
B
,
4
),
dtype
=
torch
.
int64
)
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
3
]
=
frames
for
device
in
self
.
devices
:
# lm: [B][S+1][C]
lm
=
lm_
.
to
(
device
)
# am: [B][T][C]
am
=
am_
.
to
(
device
)
symbols
=
symbols_
.
to
(
device
)
boundary
=
boundary_
.
to
(
device
)
logprobs
=
am
.
unsqueeze
(
2
)
+
lm
.
unsqueeze
(
1
)
logprobs
.
requires_grad_
()
k2_loss
=
fast_rnnt
.
rnnt_loss
(
logits
=
logprobs
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
)
k2_grad
=
torch
.
autograd
.
grad
(
k2_loss
,
logprobs
)
k2_grad
=
k2_grad
[
0
]
logprobs2
=
logprobs
.
detach
().
clone
().
float
()
logprobs2
.
requires_grad_
()
torch_loss
=
torchaudio
.
functional
.
rnnt_loss
(
logits
=
logprobs2
,
targets
=
symbols
.
int
(),
logit_lengths
=
boundary
[:,
3
].
int
(),
target_lengths
=
boundary
[:,
2
].
int
(),
blank
=
termination_symbol
,
)
torch_grad
=
torch
.
autograd
.
grad
(
torch_loss
,
logprobs2
)
torch_grad
=
torch_grad
[
0
]
assert
torch
.
allclose
(
k2_loss
,
torch_loss
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
k2_grad
,
torch_grad
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
test_rnnt_loss_smoothed
(
self
):
B
=
1
S
=
3
T
=
4
# C = 3
for
device
in
self
.
devices
:
# lm: [B][S+1][C]
lm
=
torch
.
tensor
(
[[[
0
,
0
,
1
],
[
0
,
1
,
1
],
[
1
,
0
,
1
],
[
2
,
2
,
0
]]],
dtype
=
torch
.
float
,
device
=
device
,
)
# am: [B][T][C]
am
=
torch
.
tensor
(
[[[
0
,
1
,
2
],
[
0
,
0
,
0
],
[
0
,
2
,
4
],
[
0
,
3
,
3
]]],
dtype
=
torch
.
float
,
device
=
device
,
)
termination_symbol
=
2
symbols
=
torch
.
tensor
([[
0
,
1
,
0
]],
dtype
=
torch
.
long
,
device
=
device
)
m
=
fast_rnnt
.
rnnt_loss_smoothed
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.333
,
boundary
=
None
,
)
if
device
==
torch
.
device
(
"cpu"
):
expected
=
m
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
# should be invariant to adding a constant for any frame.
lm
+=
torch
.
randn
(
B
,
S
+
1
,
1
,
device
=
device
)
am
+=
torch
.
randn
(
B
,
T
,
1
,
device
=
device
)
m
=
fast_rnnt
.
rnnt_loss_smoothed
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.333
,
boundary
=
None
,
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
def
test_rnnt_loss_pruned
(
self
):
B
=
4
T
=
300
S
=
50
C
=
10
frames
=
torch
.
randint
(
S
,
T
,
(
B
,))
seq_length
=
torch
.
randint
(
3
,
S
-
1
,
(
B
,))
T
=
torch
.
max
(
frames
)
S
=
torch
.
max
(
seq_length
)
am_
=
torch
.
randn
((
B
,
T
,
C
),
dtype
=
torch
.
float64
)
lm_
=
torch
.
randn
((
B
,
S
+
1
,
C
),
dtype
=
torch
.
float64
)
symbols_
=
torch
.
randint
(
0
,
C
-
1
,
(
B
,
S
))
terminal_symbol
=
C
-
1
boundary_
=
torch
.
zeros
((
B
,
4
),
dtype
=
torch
.
int64
)
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
3
]
=
frames
for
modified
in
[
True
,
False
]:
for
device
in
self
.
devices
:
# normal rnnt
am
=
am_
.
to
(
device
)
lm
=
lm_
.
to
(
device
)
symbols
=
symbols_
.
to
(
device
)
boundary
=
boundary_
.
to
(
device
)
t_am
=
am
.
unsqueeze
(
2
).
float
()
t_lm
=
lm
.
unsqueeze
(
1
).
float
()
t_prob
=
t_am
+
t_lm
# nonlinear transform
t_prob
=
torch
.
sigmoid
(
t_prob
)
k2_loss
=
fast_rnnt
.
rnnt_loss
(
logits
=
t_prob
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
modified
=
modified
,
)
print
(
f
"unpruned rnnt loss with modified
{
modified
}
:
{
k2_loss
}
"
)
# pruning
k2_simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
lm
=
lm
,
am
=
am
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
modified
=
modified
,
return_grad
=
True
,
reduction
=
"none"
,
)
for
r
in
range
(
2
,
50
,
5
):
ranges
=
fast_rnnt
.
get_rnnt_prune_ranges
(
px_grad
=
px_grad
,
py_grad
=
py_grad
,
boundary
=
boundary
,
s_range
=
r
,
)
# (B, T, r, C)
am_p
,
lm_p
=
fast_rnnt
.
do_rnnt_pruning
(
am
=
am
,
lm
=
lm
,
ranges
=
ranges
)
t_prob_p
=
am_p
+
lm_p
# nonlinear transform
t_prob_p
=
torch
.
sigmoid
(
t_prob_p
)
pruned_loss
=
fast_rnnt
.
rnnt_loss_pruned
(
logits
=
t_prob_p
,
symbols
=
symbols
,
ranges
=
ranges
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
modified
=
modified
,
reduction
=
"none"
,
)
print
(
f
"pruning loss with range
{
r
}
:
{
pruned_loss
}
"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
tests/requirements_test.txt  (deleted, mode 100644)
torch>=1.5
tqdm
tests/test.py  (deleted, mode 100644)
import
os
import
random
import
time
import
unittest
import
torch
from
tqdm
import
tqdm
from
torch_discounted_cumsum
import
discounted_cumsum_left
,
discounted_cumsum_right
def
get_grad
(
param
,
out
):
out
.
sum
().
backward
()
grad
=
param
.
grad
.
clone
()
del
param
.
grad
return
grad
def
discounted_cumsum_left_gold
(
input
,
gamma
):
assert
input
.
dim
()
==
2
assert
0
<=
gamma
<=
1
out
=
[]
last_col
=
torch
.
zeros
((
input
.
shape
[
0
],
1
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
for
i
in
range
(
input
.
shape
[
1
]):
cur_col
=
input
[:,
i
].
unsqueeze
(
-
1
)
last_col
=
cur_col
+
gamma
*
last_col
out
.
append
(
last_col
)
out
=
torch
.
cat
(
out
,
dim
=
1
)
return
out
def
discounted_cumsum_right_gold
(
input
,
gamma
):
assert
input
.
dim
()
==
2
assert
0
<=
gamma
<=
1
out
=
[]
last_col
=
torch
.
zeros
((
input
.
shape
[
0
],
1
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
for
i
in
reversed
(
range
(
input
.
shape
[
1
])):
cur_col
=
input
[:,
i
].
unsqueeze
(
-
1
)
last_col
=
cur_col
+
gamma
*
last_col
out
.
insert
(
0
,
last_col
)
out
=
torch
.
cat
(
out
,
dim
=
1
)
return
out
def
discounted_cumsum_lib
(
x
,
gamma
,
dir
):
return
{
'left'
:
discounted_cumsum_left
,
'right'
:
discounted_cumsum_right
,
}[
dir
](
x
,
gamma
)
def
discounted_cumsum_gold
(
x
,
gamma
,
dir
):
return
{
'left'
:
discounted_cumsum_left_gold
,
'right'
:
discounted_cumsum_right_gold
,
}[
dir
](
x
,
gamma
)
def
compute_linf
(
batchsz
,
veclen
,
dir
,
gamma
=
0.99
,
dtype
=
torch
.
float32
,
cuda
=
False
,
data
=
'randn'
,
tol
=
1e-3
,
seed
=
2021
):
torch
.
manual_seed
(
seed
)
if
data
==
'randn'
:
x
=
torch
.
randn
((
batchsz
,
veclen
),
dtype
=
dtype
)
elif
data
==
'ones'
:
x
=
torch
.
ones
((
batchsz
,
veclen
),
dtype
=
dtype
)
else
:
raise
ValueError
(
'Invalid data generation identifier'
)
if
cuda
:
x
=
x
.
cuda
()
x
=
torch
.
nn
.
Parameter
(
x
)
out_gold
=
discounted_cumsum_gold
(
x
,
gamma
,
dir
)
grad_gold
=
get_grad
(
x
,
out_gold
)
out_lib
=
discounted_cumsum_lib
(
x
,
gamma
,
dir
)
grad_lib
=
get_grad
(
x
,
out_lib
)
out_linf
=
(
out_lib
-
out_gold
).
abs
().
max
().
item
()
grad_linf
=
(
grad_lib
-
grad_gold
).
abs
().
max
().
item
()
if
out_linf
>=
tol
or
grad_linf
>=
tol
:
print
(
f
'x=
{
x
}
\n
out_gold=
{
out_gold
}
\n
out_lib=
{
out_lib
}
\n
grad_gold=
{
grad_gold
}
\n
grad_lib=
{
grad_lib
}
\n
'
)
return
out_linf
,
grad_linf
class
TestDiscountedCumSum
(
unittest
.
TestCase
):
def
test_validity
(
self
):
print
(
'Testing validity...'
)
is_cuda
=
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
,
''
)
!=
''
for
cuda
in
(
True
,
False
):
if
cuda
and
not
is_cuda
:
print
(
'Skipping validity CUDA tests'
)
continue
rng
=
random
.
Random
(
2021
)
with
tqdm
(
total
=
2
*
2
*
2
*
17
)
as
pbar
:
for
data
in
(
'ones'
,
'randn'
):
for
dtype
in
(
torch
.
float32
,
torch
.
float64
):
for
i
in
range
(
2
):
batchsz
=
8
**
i
for
j
in
range
(
17
):
veclen
=
max
(
1
,
2
**
j
+
rng
.
randint
(
-
1
,
1
))
gamma
=
rng
.
random
()
seed
=
rng
.
randint
(
0
,
2
**
16
)
dir
=
rng
.
choice
([
'left'
,
'right'
])
tol
=
2e-3
out_linf
,
grad_linf
=
compute_linf
(
batchsz
,
veclen
,
dir
,
gamma
,
dtype
,
cuda
,
data
,
tol
,
seed
)
msg
=
f
'Validity test failed with batchsz=
{
batchsz
}
, veclen=
{
veclen
}
, dir=
{
dir
}
, '
\
f
'gamma=
{
gamma
}
, dtype=
{
dtype
}
, cuda=
{
cuda
}
, data=
{
data
}
, seed=
{
seed
}
, '
\
f
'out_linf=
{
out_linf
}
, grad_linf=
{
grad_linf
}
'
self
.
assertLess
(
out_linf
,
tol
,
msg
)
self
.
assertLess
(
grad_linf
,
tol
,
msg
)
pbar
.
update
(
1
)
def
test_precision
(
self
):
print
(
'Testing precision...'
)
is_cuda
=
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
,
''
)
!=
''
if
not
is_cuda
:
print
(
'Skipping precision tests'
)
return
batchsz
=
1
veclen
=
10000
gamma
=
0.99
dir
=
'right'
for
data
in
(
'ones'
,
'randn'
):
if
data
==
'ones'
:
precision_factor
=
2.0
else
:
precision_factor
=
1.1
torch
.
manual_seed
(
2021
)
if
data
==
'randn'
:
x_32
=
torch
.
randn
((
batchsz
,
veclen
),
dtype
=
torch
.
float32
)
elif
data
==
'ones'
:
x_32
=
torch
.
ones
((
batchsz
,
veclen
),
dtype
=
torch
.
float32
)
else
:
raise
ValueError
(
'Invalid data generation identifier'
)
x_32
=
x_32
.
cuda
()
x_64
=
x_32
.
double
()
gold_64
=
discounted_cumsum_gold
(
x_64
,
gamma
,
dir
)
gold_32
=
discounted_cumsum_gold
(
x_32
,
gamma
,
dir
).
double
()
lib_32
=
discounted_cumsum_lib
(
x_32
,
gamma
,
dir
).
double
()
err_32_gold
=
(
gold_32
-
gold_64
).
abs
().
max
().
item
()
err_32_lib
=
(
lib_32
-
gold_64
).
abs
().
max
().
item
()
msg
=
f
'Precision improvement test failed with data=
{
data
}
, '
\
f
'err_32_gold=
{
err_32_gold
}
, err_32_lib=
{
err_32_lib
}
'
self
.
assertLess
(
precision_factor
*
err_32_lib
,
err_32_gold
,
msg
)
print
(
f
'data=
{
data
}
\n
err_32_gold=
{
err_32_gold
:
10.8
f
}
\n
err_32_lib =
{
err_32_lib
:
10.8
f
}
'
)
def
test_speed
(
self
):
print
(
'Testing speed...'
)
is_cuda
=
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
,
''
)
!=
''
NUM_RUNS
=
30
NUM_RUNS_GOLD
=
6
if
not
is_cuda
:
print
(
'Skipping speed tests'
)
return
gamma
=
0.99
x_32
=
torch
.
randn
((
1
,
100000
),
dtype
=
torch
.
float32
)
x_32
+=
torch
.
ones_like
(
x_32
)
x_32_gpu
=
x_32
.
cuda
()
timer
=
time
.
clock_gettime
(
time
.
CLOCK_MONOTONIC
)
for
_
in
tqdm
(
range
(
NUM_RUNS_GOLD
),
desc
=
'gold'
,
leave
=
True
):
discounted_cumsum_right_gold
(
x_32
,
gamma
)
dur_gold
=
time
.
clock_gettime
(
time
.
CLOCK_MONOTONIC
)
-
timer
dur_gold
=
dur_gold
*
NUM_RUNS
/
NUM_RUNS_GOLD
timer
=
time
.
clock_gettime
(
time
.
CLOCK_MONOTONIC
)
for
_
in
tqdm
(
range
(
NUM_RUNS
),
desc
=
'lib_cpu'
,
leave
=
True
):
discounted_cumsum_right
(
x_32
,
gamma
)
dur_lib_cpu
=
time
.
clock_gettime
(
time
.
CLOCK_MONOTONIC
)
-
timer
timer
=
time
.
clock_gettime
(
time
.
CLOCK_MONOTONIC
)
for
_
in
tqdm
(
range
(
NUM_RUNS
),
desc
=
'lib_cuda'
,
leave
=
True
):
discounted_cumsum_right
(
x_32_gpu
,
gamma
)
dur_lib_cuda
=
time
.
clock_gettime
(
time
.
CLOCK_MONOTONIC
)
-
timer
print
(
f
'dur_gold:
{
dur_gold
:
7.4
f
}
sec'
)
print
(
f
'dur_lib_cpu:
{
dur_lib_cpu
:
7.4
f
}
sec'
)
print
(
f
'dur_lib_cuda:
{
dur_lib_cuda
:
7.4
f
}
sec'
)
print
(
f
'speedup gold -> lib_cpu:
{
dur_gold
/
dur_lib_cpu
:
5.2
f
}
'
)
print
(
f
'speedup gold -> lib_cuda:
{
dur_gold
/
dur_lib_cuda
:
5.2
f
}
'
)
print
(
f
'speedup lib_cpu -> lib_cuda:
{
dur_lib_cpu
/
dur_lib_cuda
:
5.2
f
}
'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
torch_mutual_information/__init__.py  (deleted, mode 100644)
from .mutual_information import mutual_information_recursion, joint_mutual_information_recursion
from .rnnt import get_rnnt_logprobs, rnnt_loss_simple, rnnt_loss_aux
torch_mutual_information/mutual_information.py  (deleted, mode 100644)
import
os
import
torch
from
torch
import
Tensor
from
typing
import
Tuple
,
Optional
,
Sequence
from
torch.utils.cpp_extension
import
load
VERBOSE
=
False
def
_resolve
(
name
):
return
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
name
)
try
:
import
torch_mutual_information_cpu
except
ImportError
:
if
VERBOSE
:
print
(
'Falling back to JIT compiling torch_mutual_information_cpu'
)
torch_mutual_information_cpu
=
load
(
name
=
'torch_mutual_information_cpu'
,
sources
=
[
_resolve
(
'mutual_information_cpu.cpp'
),
],
verbose
=
VERBOSE
,
)
try
:
import
torch_mutual_information_cuda
except
ImportError
:
if
VERBOSE
:
print
(
'Falling back to JIT compiling torch_mutual_information_cuda'
)
torch_mutual_information_cuda
=
None
if
torch
.
cuda
.
is_available
():
torch_mutual_information_cuda
=
load
(
name
=
'torch_mutual_information_cuda'
,
sources
=
[
_resolve
(
'mutual_information_cuda.cpp'
),
_resolve
(
'mutual_information_cuda_kernel.cu'
),
],
verbose
=
VERBOSE
,
)
def
_mutual_information_forward_dispatcher
(
px
:
torch
.
Tensor
,
py
:
torch
.
Tensor
,
boundary
:
torch
.
Tensor
,
p
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
px
.
is_cuda
:
if
torch_mutual_information_cuda
is
None
:
raise
EnvironmentError
(
f
'Failed to load native CUDA module'
)
return
torch_mutual_information_cuda
.
mutual_information_cuda
(
px
,
py
,
boundary
,
p
)
else
:
return
torch_mutual_information_cpu
.
mutual_information_cpu
(
px
,
py
,
boundary
,
p
)
def
_mutual_information_backward_dispatcher
(
px
:
torch
.
Tensor
,
py
:
torch
.
Tensor
,
boundary
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
ans_grad
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
px
.
is_cuda
:
if
torch_mutual_information_cuda
is
None
:
raise
EnvironmentError
(
f
'Failed to load native CUDA module'
)
overwrite_ans_grad
=
True
if
overwrite_ans_grad
:
ans_grad_copy
=
ans_grad
.
clone
()
ans
=
tuple
(
torch_mutual_information_cuda
.
mutual_information_backward_cuda
(
px
,
py
,
boundary
,
p
,
ans_grad_copy
,
overwrite_ans_grad
))
if
overwrite_ans_grad
:
if
not
torch
.
allclose
(
ans_grad
,
ans_grad_copy
,
rtol
=
1.0e-02
):
print
(
f
"Warning: possible excesssive roundoff in mutual information backward "
f
"recursion:
{
ans_grad
}
vs.
{
ans_grad_copy
}
"
);
return
ans
else
:
return
tuple
(
torch_mutual_information_cpu
.
mutual_information_backward_cpu
(
px
,
py
,
boundary
,
p
,
ans_grad
))
class
MutualInformationRecursionFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
px
:
torch
.
Tensor
,
py
:
torch
.
Tensor
,
boundary
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
(
B
,
S
,
T1
)
=
px
.
shape
T
=
T1
-
1
;
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
if
boundary
is
not
None
:
assert
boundary
.
shape
==
(
B
,
4
)
else
:
boundary
=
torch
.
zeros
(
0
,
0
,
dtype
=
torch
.
int64
,
device
=
px
.
device
)
# p is a tensor of shape (B, S + 1, T + 1) were p[s][t] is the
# the mutual information of the pair of subsequences of x and y that are of
# length s and t respectively. p[0][0] will be 0.0 and p[S][T] is
# the mutual information of the entire pair of sequences, i.e. of lengths
# S and T respectively.
# It is computed as follows (in C++ and CUDA):
# p[b,0,0] = 0.0
# p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
# p[b,s,t-1] + py[b,s,t-1])
# if s > 0 or t > 0,
# treating values with any -1 index as -infinity.
# .. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0.
p
=
torch
.
empty
(
B
,
S
+
1
,
T
+
1
,
device
=
px
.
device
,
dtype
=
px
.
dtype
)
ans
=
_mutual_information_forward_dispatcher
(
px
,
py
,
boundary
,
p
)
# print(f"p = {p}, boundary = {boundary}, psum={p.sum()}")
if
px
.
requires_grad
or
py
.
requires_grad
:
ctx
.
save_for_backward
(
px
,
py
,
boundary
,
p
)
return
ans
@
staticmethod
def
backward
(
ctx
,
ans_grad
:
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
None
]:
(
px
,
py
,
boundary
,
p
)
=
ctx
.
saved_tensors
(
px_grad
,
py_grad
)
=
_mutual_information_backward_dispatcher
(
px
,
py
,
boundary
,
p
,
ans_grad
)
return
(
px_grad
,
py_grad
,
None
)
def
mutual_information_recursion
(
px
,
py
,
boundary
=
None
):
"""A recursion that is useful in computing mutual information between two sequences of
real vectors, but may be useful more generally in sequence-to-sequence tasks where
monotonic alignment between pairs of sequences is desired. The definitions of
the arguments are definitions that would be used when computing this type of
mutual information, but you can also view them as arbitrary quantities and just
make use of the formula computed by this function.
Args:
px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
where B is the batch size, S is the length of the 'x' sequence
(including representations of EOS symbols but not BOS symbols), and S is the
length of the 'y' sequence (including representations of
EOS symbols but not BOS symbols). In the mutual information application,
px[b][s][t] would represent the following log odds ratio; ignoring
the b index on the right to make the notation more compact,
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an x value as opposed to a y value. In
practice it might be computed as a + b, where a is the log
probability of choosing to extend the sequence of length (s,t)
with an x as opposed to a y value; and b might in practice be
of the form:
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where N is the number of terms that the sum over t' included, which
might include some or all of the other sequences as well as this one.
Note: we don't require px and py to be contiguous, but the
code assumes for optimization purposes that the T axis has
stride 1.
py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
representing
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat x and y differently; the only difference
is that for optimization purposes we assume the last axis (the t axis)
has stride of 1; this is true if px and py are contiguous.
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end < S and
0 <= t_begin <= t_end < T (this implies that empty sequences are allowed). If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
Returns:
Returns a torch.Tensor of shape [B], containing the log of the mutual
information between the b'th pair of sequences. This is defined by
the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
representing a mutual information between sub-sequences of lengths s and t:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
where we handle edge cases by treating quantities with negative indexes
as -infinity. The extension to cases where the boundaries are specified
should be obvious; it just works on shorter sequences with offsets into
px and py.
"""
assert
px
.
ndim
==
3
B
,
S
,
T1
=
px
.
shape
T
=
T1
-
1
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
px
.
dtype
==
py
.
dtype
(
B
,
S
,
T
)
=
px
.
shape
if
boundary
is
not
None
:
assert
boundary
.
dtype
==
torch
.
int64
assert
boundary
.
shape
==
(
B
,
4
)
for
[
s_begin
,
t_begin
,
s_end
,
t_end
]
in
boundary
.
to
(
'cpu'
).
tolist
():
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
t_begin
<=
t_end
<=
T
# The following assertions are for efficiency
assert
px
.
stride
()[
-
1
]
==
1
assert
py
.
stride
()[
-
1
]
==
1
return
MutualInformationRecursionFunction
.
apply
(
px
,
py
,
boundary
)
def
_inner
(
a
:
Tensor
,
b
:
Tensor
)
->
Tensor
:
"""
Does inner product on the last dimension, with expected broadcasting, i.e. equivalent to
(a * b).sum(dim=-1)
without creating a large temporary.
"""
assert
a
.
shape
[
-
1
]
==
b
.
shape
[
-
1
]
# last last dim be K
a
=
a
.
unsqueeze
(
-
2
)
# (..., 1, K)
b
=
b
.
unsqueeze
(
-
1
)
# (..., K, 1)
c
=
torch
.
matmul
(
a
,
b
)
# (..., 1, 1)
return
c
.
squeeze
(
-
1
).
squeeze
(
-
1
)
def
joint_mutual_information_recursion
(
px
:
Sequence
[
Tensor
],
py
:
Sequence
[
Tensor
],
boundary
:
Optional
[
Tensor
]
=
None
)
->
Sequence
[
Tensor
]:
"""A recursion that is useful for modifications of RNN-T and similar loss functions,
where the recursion probabilities have a number of terms and you want them reported
separately. See mutual_information_recursion() for more documentation of the
basic aspects of this.
Args:
px: a sequence of Tensors, each of the same shape [B][S][T+1]
py: a sequence of Tensor, each of the same shape [B][S+1][T], the sequence must be
the same length as px.
boundary: optionally, a LongTensor of shape [B][4] containing rows
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end < S and
0 <= t_begin <= t_end < T, defaulting to [0, 0, S, T].
These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
Returns:
a Tensor of shape (len(px), B),
whose sum over dim 0 is the total log-prob of the recursion mentioned below, per sequence.
The first element of the sequence of length len(px) is "special", in that it has an offset term
reflecting the difference between sum-of-log and log-of-sum; for more interpretable
loss values, the "main" part of your loss function should be first.
The recursion below applies if boundary == None, when it defaults
to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px and py:
p = tensor of shape (B, S+1, T+1), containing -infinity
p[b,0,0] = 0.0
# do the following in loop over s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
p[b,s,t-1] + py_sum[b,s,t-1])
(if s > 0 or t > 0)
return b[:][S][T]
This function lets you implement the above recursion efficiently, except
that it gives you a breakdown of the contribution from all the elements of
px and py separately. As noted above, the first element of the
sequence is "special".
"""
N
=
len
(
px
)
assert
len
(
py
)
==
N
and
N
>
0
B
,
S
,
T1
=
px
[
0
].
shape
T
=
T1
-
1
assert
py
[
0
].
shape
==
(
B
,
S
+
1
,
T
)
assert
px
[
0
].
dtype
==
py
[
0
].
dtype
px_cat
=
torch
.
stack
(
px
,
dim
=
0
)
# (N, B, S, T+1)
py_cat
=
torch
.
stack
(
py
,
dim
=
0
)
# (N, B, S+1, T)
px_tot
=
px_cat
.
sum
(
dim
=
0
)
# (B, S, T+1)
py_tot
=
py_cat
.
sum
(
dim
=
0
)
# (B, S+1, T)
if
boundary
is
not
None
:
assert
boundary
.
dtype
==
torch
.
int64
assert
boundary
.
shape
==
(
B
,
4
)
for
[
s_begin
,
t_begin
,
s_end
,
t_end
]
in
boundary
.
to
(
'cpu'
).
tolist
():
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
t_begin
<=
t_end
<=
T
else
:
boundary
=
torch
.
zeros
(
0
,
0
,
dtype
=
torch
.
int64
,
device
=
px_tot
.
device
)
px_tot
,
py_tot
=
px_tot
.
contiguous
(),
py_tot
.
contiguous
()
# The following assertions are for efficiency
assert
px_tot
.
stride
()[
-
1
]
==
1
and
px_tot
.
ndim
==
3
assert
py_tot
.
stride
()[
-
1
]
==
1
and
py_tot
.
ndim
==
3
p
=
torch
.
empty
(
B
,
S
+
1
,
T
+
1
,
device
=
px_tot
.
device
,
dtype
=
px_tot
.
dtype
)
# note, tot_probs is without grad.
tot_probs
=
_mutual_information_forward_dispatcher
(
px_tot
,
py_tot
,
boundary
,
p
)
# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad
=
torch
.
ones
(
B
,
device
=
px_tot
.
device
,
dtype
=
px_tot
.
dtype
)
(
px_grad
,
py_grad
)
=
_mutual_information_backward_dispatcher
(
px_tot
,
py_tot
,
boundary
,
p
,
ans_grad
)
px_grad
,
py_grad
=
px_grad
.
reshape
(
1
,
B
,
-
1
),
py_grad
.
reshape
(
1
,
B
,
-
1
)
px_cat
,
py_cat
=
px_cat
.
reshape
(
N
,
B
,
-
1
),
py_cat
.
reshape
(
N
,
B
,
-
1
)
x_prods
=
_inner
(
px_grad
,
px_cat
)
# (N, B)
y_prods
=
_inner
(
py_grad
,
py_cat
)
# (N, B)
# If all the occupation counts were exactly 1.0 (i.e. no partial counts),
# "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
# will be more positive due to the difference between log-of-sum and
# sum-of-log
prods
=
x_prods
+
y_prods
# (N, B)
with
torch
.
no_grad
():
offset
=
tot_probs
-
prods
.
sum
(
dim
=
0
)
# (B,)
prods
[
0
]
+=
offset
return
prods
# (N, B)
torch_mutual_information/mutual_information_cpu.cpp  (deleted, mode 100644)
#include <math.h> // for log1p, log1pf
#include <torch/extension.h>
inline
double
Exp
(
double
x
)
{
return
exp
(
x
);
}
inline
double
Exp
(
float
x
)
{
return
expf
(
x
);
}
// returns log(exp(x) + exp(y)).
inline
double
LogAdd
(
double
x
,
double
y
)
{
double
diff
;
if
(
x
<
y
)
{
diff
=
x
-
y
;
x
=
y
;
}
else
{
diff
=
y
-
x
;
}
// diff is negative. x is now the larger one.
if
(
diff
>=
-
1000
)
{
double
res
;
res
=
x
+
log1p
(
exp
(
diff
));
return
res
;
}
return
x
;
// return the larger one.
}
// returns log(exp(x) + exp(y)).
inline
float
LogAdd
(
float
x
,
float
y
)
{
float
diff
;
if
(
x
<
y
)
{
diff
=
x
-
y
;
x
=
y
;
}
else
{
diff
=
y
-
x
;
}
// diff is negative. x is now the larger one.
if
(
diff
>=
-
200
)
{
float
res
;
res
=
x
+
log1pf
(
expf
(
diff
));
return
res
;
}
return
x
;
// return the larger one.
}
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
// px: of shape [B, S, T+1] where
torch
::
Tensor
mutual_information_cpu
(
torch
::
Tensor
px
,
torch
::
Tensor
py
,
torch
::
Tensor
boundary
,
torch
::
Tensor
p
)
{
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
TORCH_CHECK
(
boundary
.
dim
()
==
2
,
"boundary must be 2-dimensional."
);
TORCH_CHECK
(
px
.
device
().
is_cpu
()
&&
py
.
device
().
is_cpu
()
&&
p
.
device
().
is_cpu
(),
"inputs must be CPU tensors"
);
auto
scalar_t
=
px
.
scalar_type
();
auto
opts
=
torch
::
TensorOptions
().
dtype
(
scalar_t
).
device
(
px
.
device
());
const
int
B
=
px
.
size
(
0
),
S
=
px
.
size
(
1
),
T
=
px
.
size
(
2
)
-
1
;
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
((
boundary
.
size
(
0
)
==
0
&&
boundary
.
size
(
1
)
==
0
)
||
(
boundary
.
size
(
0
)
==
B
&&
boundary
.
size
(
1
)
==
4
));
TORCH_CHECK
(
boundary
.
device
().
is_cpu
()
&&
boundary
.
dtype
()
==
torch
::
kInt64
);
torch
::
Tensor
ans
=
torch
::
empty
({
B
},
opts
);
bool
has_boundary
=
(
boundary
.
size
(
0
)
!=
0
);
AT_DISPATCH_FLOATING_TYPES
(
px
.
scalar_type
(),
"mutual_information_cpu_loop"
,
([
&
]
{
auto
px_a
=
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
py_a
=
py
.
packed_accessor32
<
scalar_t
,
3
>
(),
p_a
=
p
.
packed_accessor32
<
scalar_t
,
3
>
();
auto
boundary_a
=
boundary
.
packed_accessor32
<
int64_t
,
2
>
();
auto
ans_a
=
ans
.
packed_accessor32
<
scalar_t
,
1
>
();
for
(
int
b
=
0
;
b
<
B
;
b
++
)
{
int
s_begin
,
s_end
,
t_begin
,
t_end
;
if
(
has_boundary
)
{
s_begin
=
boundary_a
[
b
][
0
];
t_begin
=
boundary_a
[
b
][
1
];
s_end
=
boundary_a
[
b
][
2
];
t_end
=
boundary_a
[
b
][
3
];
}
else
{
s_begin
=
0
;
t_begin
=
0
;
s_end
=
S
;
t_end
=
T
;
}
p_a
[
b
][
s_begin
][
t_begin
]
=
0.0
;
for
(
int
s
=
s_begin
+
1
;
s
<=
s_end
;
++
s
)
p_a
[
b
][
s
][
t_begin
]
=
p_a
[
b
][
s
-
1
][
t_begin
]
+
px_a
[
b
][
s
-
1
][
t_begin
];
for
(
int
t
=
t_begin
+
1
;
t
<=
t_end
;
++
t
)
p_a
[
b
][
s_begin
][
t
]
=
p_a
[
b
][
s_begin
][
t
-
1
]
+
py_a
[
b
][
s_begin
][
t
-
1
];
for
(
int
s
=
s_begin
+
1
;
s
<=
s_end
;
++
s
)
{
scalar_t
p_s_t1
=
p_a
[
b
][
s
][
t_begin
];
for
(
int
t
=
t_begin
+
1
;
t
<=
t_end
;
++
t
)
{
// The following statement is a small optimization of:
// p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
p_a
[
b
][
s
][
t
]
=
p_s_t1
=
LogAdd
(
p_a
[
b
][
s
-
1
][
t
]
+
px_a
[
b
][
s
-
1
][
t
],
p_s_t1
+
py_a
[
b
][
s
][
t
-
1
]);
}
}
ans_a
[
b
]
=
p_a
[
b
][
s_end
][
t_end
];
}
}));
return
ans
;
}
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std
::
vector
<
torch
::
Tensor
>
mutual_information_backward_cpu
(
torch
::
Tensor
px
,
torch
::
Tensor
py
,
torch
::
Tensor
boundary
,
torch
::
Tensor
p
,
torch
::
Tensor
ans_grad
)
{
TORCH_CHECK
(
px
.
dim
()
==
3
,
"px must be 3-dimensional"
);
TORCH_CHECK
(
py
.
dim
()
==
3
,
"py must be 3-dimensional."
);
TORCH_CHECK
(
p
.
dim
()
==
3
,
"p must be 3-dimensional."
);
TORCH_CHECK
(
boundary
.
dim
()
==
2
,
"boundary must be 2-dimensional."
);
TORCH_CHECK
(
ans_grad
.
dim
()
==
1
,
"ans_grad must be 3-dimensional."
);
TORCH_CHECK
(
px
.
device
().
is_cpu
()
&&
py
.
device
().
is_cpu
()
&&
p
.
device
().
is_cpu
()
&&
ans_grad
.
device
().
is_cpu
(),
"inputs must be CPU tensors"
);
auto
scalar_t
=
px
.
scalar_type
();
auto
opts
=
torch
::
TensorOptions
().
dtype
(
scalar_t
).
device
(
px
.
device
());
const
int
B
=
px
.
size
(
0
),
S
=
px
.
size
(
1
),
T
=
px
.
size
(
2
)
-
1
;
TORCH_CHECK
(
py
.
size
(
0
)
==
B
&&
py
.
size
(
1
)
==
S
+
1
&&
py
.
size
(
2
)
==
T
);
TORCH_CHECK
(
p
.
size
(
0
)
==
B
&&
p
.
size
(
1
)
==
S
+
1
&&
p
.
size
(
2
)
==
T
+
1
);
TORCH_CHECK
((
boundary
.
size
(
0
)
==
0
&&
boundary
.
size
(
1
)
==
0
)
||
(
boundary
.
size
(
0
)
==
B
&&
boundary
.
size
(
1
)
==
4
));
TORCH_CHECK
(
boundary
.
device
().
is_cpu
()
&&
boundary
.
dtype
()
==
torch
::
kInt64
);
bool
has_boundary
=
(
boundary
.
size
(
0
)
!=
0
);
torch
::
Tensor
p_grad
=
torch
::
zeros
({
B
,
S
+
1
,
T
+
1
},
opts
),
px_grad
=
(
has_boundary
?
torch
::
zeros
({
B
,
S
,
T
+
1
},
opts
)
:
torch
::
empty
({
B
,
S
,
T
+
1
},
opts
)),
py_grad
=
(
has_boundary
?
torch
::
zeros
({
B
,
S
+
1
,
T
},
opts
)
:
torch
::
empty
({
B
,
S
+
1
,
T
},
opts
));
AT_DISPATCH_FLOATING_TYPES
(
px
.
scalar_type
(),
"mutual_information_cpu_backward_loop"
,
([
&
]
{
auto
px_a
=
px
.
packed_accessor32
<
scalar_t
,
3
>
(),
// py_a = py.packed_accessor32<scalar_t, 3>(),
p_a
=
p
.
packed_accessor32
<
scalar_t
,
3
>
(),
p_grad_a
=
p_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
px_grad_a
=
px_grad
.
packed_accessor32
<
scalar_t
,
3
>
(),
py_grad_a
=
py_grad
.
packed_accessor32
<
scalar_t
,
3
>
();
auto
ans_grad_a
=
ans_grad
.
packed_accessor32
<
scalar_t
,
1
>
();
auto
boundary_a
=
boundary
.
packed_accessor32
<
int64_t
,
2
>
();
for
(
int
b
=
0
;
b
<
B
;
b
++
)
{
int
s_begin
,
s_end
,
t_begin
,
t_end
;
if
(
has_boundary
)
{
s_begin
=
boundary_a
[
b
][
0
];
t_begin
=
boundary_a
[
b
][
1
];
s_end
=
boundary_a
[
b
][
2
];
t_end
=
boundary_a
[
b
][
3
];
}
else
{
s_begin
=
0
;
s_end
=
S
;
t_begin
=
0
;
t_end
=
T
;
}
// Backprop for: ans_a[b] = p_a[b][s_end][t_end];
p_grad_a
[
b
][
s_end
][
t_end
]
=
ans_grad_a
[
b
];
for
(
int
s
=
s_end
;
s
>
s_begin
;
--
s
)
{
for
(
int
t
=
t_end
;
t
>
t_begin
;
--
t
)
{
// The s,t indexes correspond to
// The statement we are backpropagating here is:
// p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
scalar_t
term1
=
p_a
[
b
][
s
-
1
][
t
]
+
px_a
[
b
][
s
-
1
][
t
],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total
=
p_a
[
b
][
s
][
t
],
term1_deriv
=
exp
(
term1
-
total
),
term2_deriv
=
1.0
-
term1_deriv
,
grad
=
p_grad_a
[
b
][
s
][
t
],
term1_grad
=
term1_deriv
*
grad
,
term2_grad
=
term2_deriv
*
grad
;
px_grad_a
[
b
][
s
-
1
][
t
]
=
term1_grad
;
p_grad_a
[
b
][
s
-
1
][
t
]
=
term1_grad
;
py_grad_a
[
b
][
s
][
t
-
1
]
=
term2_grad
;
p_grad_a
[
b
][
s
][
t
-
1
]
+=
term2_grad
;
}
}
for
(
int
t
=
t_end
;
t
>
t_begin
;
--
t
)
{
// Backprop for:
// p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
scalar_t
this_p_grad
=
p_grad_a
[
b
][
s_begin
][
t
];
p_grad_a
[
b
][
s_begin
][
t
-
1
]
+=
this_p_grad
;
py_grad_a
[
b
][
s_begin
][
t
-
1
]
=
this_p_grad
;
}
for
(
int
s
=
s_end
;
s
>
s_begin
;
--
s
)
{
// Backprop for:
// p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
scalar_t
this_p_grad
=
p_grad_a
[
b
][
s
][
t_begin
];
p_grad_a
[
b
][
s
-
1
][
t_begin
]
+=
this_p_grad
;
px_grad_a
[
b
][
s
-
1
][
t_begin
]
=
this_p_grad
;
}
// There is no backprop for:
// p_a[b][s_begin][t_begin] = 0.0;
// .. but we can use this for a check, that the grad at the beginning
// of the sequence is equal to the grad at the end of the sequence.
if
(
ans_grad_a
[
b
]
!=
0.0
)
{
float
grad_ratio
=
p_grad_a
[
b
][
s_begin
][
t_begin
]
/
ans_grad_a
[
b
];
if
(
fabs
(
grad_ratio
-
1.0
)
>
0.01
)
{
printf
(
"Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f
\n
"
,
(
float
)
p_grad_a
[
b
][
s_begin
][
t_begin
],
(
float
)
ans_grad_a
[
b
]);
}
}
}
}));
// std::cout << "p_grad = " << p_grad;
return
std
::
vector
<
torch
::
Tensor
>
({
px_grad
,
py_grad
});
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"mutual_information_cpu"
,
&
mutual_information_cpu
,
"Integrated convolution forward function (CPU)"
);
m
.
def
(
"mutual_information_backward_cpu"
,
&
mutual_information_backward_cpu
,
"Integrated convolution backward function (CPU)"
);
}
torch_mutual_information/mutual_information_cuda.cpp  (deleted, mode 100644)
#include <torch/extension.h>

/*
  Forward of mutual_information.  See also the """ ... """ comment of
  `mutual_information` in mutual_information.py.  This is the core recursion
  in the sequence-to-sequence mutual information computation.

  Args:
      px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
          generating the next x in the sequence, i.e.
          px[b][s][t] is the log of
             p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
          i.e. the log-prob of generating x_s given subsequences of lengths
          (s, t), divided by the prior probability of generating x_s.  (See
          mutual_information.py for more info).
      py: The log-odds ratio of generating the next y in the sequence.
          Shape [B][S + 1][T]
      p: This function writes to p[b][s][t] the mutual information between
         sub-sequences of x and y of length s and t respectively, from the
         b'th sequences in the batch.  Its shape is [B][S + 1][T + 1].
         Concretely, this function implements the following recursion,
         in the case where s_begin == t_begin == 0:

             p[b,0,0] = 0.0
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
                        if s > 0 or t > 0,
             treating values with any -1 index as -infinity.

         .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
      boundary: If set, a tensor of shape [B][4] of type int64_t which
          contains, for each batch element b,
             boundary[b] = [s_begin, t_begin, s_end, t_end],
          the beginning and end (i.e. one-past-the-last) of the x and y
          sequences that we should process.  Alternatively, may be a tensor of
          shape [0][0] and type int64_t; the elements will then default to
          (0, 0, S, T).
      ans: a tensor `ans` of shape [B], where this function will set
             ans[b] = p[b][s_end][t_end],
          with s_end and t_end being (S, T) if `boundary` was not specified,
          and (boundary[b][2], boundary[b][3]) otherwise.
          `ans` represents the mutual information between each pair of
          sequences (i.e. x[b] and y[b], although the sequences are not
          supplied directly to this function).

  The block-dim and grid-dim must both be 1-dimensional, and the block-dim
  must be at least 128.
*/
torch::Tensor mutual_information_cuda(torch::Tensor px,        // [B][S][T+1]
                                      torch::Tensor py,        // [B][S+1][T]
                                      torch::Tensor boundary,  // [B][4], int64_t.
                                      torch::Tensor p);  // [B][S+1][T+1]; an output

/*
  backward of mutual_information; returns (grad_px, grad_py)

  if overwrite_ans_grad == true, this function will overwrite ans_grad with a
  value that, if the computation worked correctly, should be identical to or
  very close to the value of ans_grad at entry.  This can be used
  to validate the correctness of this code.
*/
std::vector<torch::Tensor>
mutual_information_backward_cuda(torch::Tensor px, torch::Tensor py,
                                 torch::Tensor boundary, torch::Tensor p,
                                 torch::Tensor ans_grad, bool overwrite_ans_grad);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mutual_information_cuda", &mutual_information_cuda,
        "Mutual information forward function (CUDA)");
  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda,
        "Mutual information backward function (CUDA)");
}
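For reference, the recursion documented above can also be written as a minimal, unoptimized Python sketch for a single batch element with the default boundary (0, 0, S, T). The name mutual_information_naive is illustrative only and is not part of either extension module.

import torch

def mutual_information_naive(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    # px: [S][T+1], py: [S+1][T] for one batch element, default boundary.
    # Fills p[s][t] with the log-add recursion from the comment above and
    # returns p[S][T], the total mutual information for this pair.
    S, T = px.shape[0], py.shape[1]
    neg_inf = float('-inf')
    p = torch.full((S + 1, T + 1), neg_inf, dtype=px.dtype)
    p[0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            # Values with a -1 index are treated as -infinity.
            term_x = (p[s - 1, t] + px[s - 1, t] if s > 0
                      else torch.tensor(neg_inf, dtype=px.dtype))
            term_y = (p[s, t - 1] + py[s, t - 1] if t > 0
                      else torch.tensor(neg_inf, dtype=px.dtype))
            p[s, t] = torch.logaddexp(term_x, term_y)
    return p[S, T]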
torch_mutual_information/mutual_information_test.py
deleted
100644 → 0
# Caution: this will fail occasionally due to cutoffs not being quite large enough.
# As long as it passes most of the time, it's OK.
import random
import torch
from torch_mutual_information import (mutual_information_recursion,
                                      joint_mutual_information_recursion)


def test_mutual_information_basic():
    print("Running test_mutual_information_basic()")
    for _iter in range(100):
        (B, S, T) = (random.randint(1, 10),
                     random.randint(1, 200),
                     random.randint(1, 200))
        random_px = (random.random() < 0.2)
        random_py = (random.random() < 0.2)
        random_boundary = (random.random() < 0.7)
        big_px = (random.random() < 0.2)
        big_py = (random.random() < 0.2)
        print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, "
              f"random_py={random_py}, big_px={big_px}, big_py={big_py}, "
              f"random_boundary={random_boundary}")

        for dtype in [torch.float32, torch.float64]:
            px_grads = []
            py_grads = []
            m_vals = []
            for device in [torch.device('cpu'), torch.device('cuda:0')]:
                print("dtype = ", dtype, ", device = ", device)
                if random_boundary:
                    def get_boundary_row():
                        s_begin = random.randint(0, S - 1)
                        t_begin = random.randint(0, T - 1)
                        s_end = random.randint(s_begin, S)  # allow empty sequence
                        t_end = random.randint(t_begin, T)  # allow empty sequence
                        return [s_begin, t_begin, s_end, t_end]

                    if device == torch.device('cpu'):
                        boundary = torch.tensor(
                            [get_boundary_row() for _ in range(B)],
                            dtype=torch.int64, device=device)
                    else:
                        boundary = boundary.to(device)
                else:
                    # Use default boundary, but either specified directly or not.
                    if random.random() < 0.5:
                        boundary = torch.tensor(
                            [0, 0, S, T],
                            dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
                    else:
                        boundary = None

                if device == torch.device('cpu'):
                    if random_px:
                        px = torch.randn(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                    else:
                        px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                    # px and py get exponentiated, and then multiplied together up to
                    # 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big
                    # number that could lead to overflow.
                    if big_px:
                        px += 15.0
                    if random_py:
                        py = torch.randn(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                    else:
                        py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                    if big_py:
                        py += 15.0
                else:
                    px = px.to(device).detach()
                    py = py.to(device).detach()
                px.requires_grad = True
                py.requires_grad = True

                # m = mutual_information_recursion(px, py, None)
                m = mutual_information_recursion(px, py, boundary)
                m2 = joint_mutual_information_recursion((px,), (py,), boundary)
                m3 = joint_mutual_information_recursion((px * 0.5, px * 0.5),
                                                        (py * 0.5, py * 0.5),
                                                        boundary)
                print("m3, before sum, = ", m3)
                m3 = m3.sum(dim=0)  # it is supposed to be identical only after
                                    # summing over dim 0, corresponding to the
                                    # sequence dim
                print("m = ", m, ", size = ", m.shape)
                print("m2 = ", m2, ", size = ", m2.shape)
                print("m3 = ", m3, ", size = ", m3.shape)
                assert torch.allclose(m, m2)
                assert torch.allclose(m, m3)
                # print("exp(m) = ", m.exp())

                # the loop this is in checks that the CPU and CUDA versions give
                # the same derivative; by randomizing which of m, m2 or m3 we
                # backprop, we also ensure that the joint version of the code
                # gives the same derivative as the regular version
                scale = 3
                if random.random() < 0.5:
                    (m.sum() * scale).backward()
                elif random.random() < 0.5:
                    (m2.sum() * scale).backward()
                else:
                    (m3.sum() * scale).backward()
                # print("px_grad = ", px.grad)
                # print("py_grad = ", py.grad)
                px_grads.append(px.grad.to('cpu'))
                py_grads.append(py.grad.to('cpu'))
                m_vals.append(m.to('cpu'))

            if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-02, rtol=1.0e-02):
                print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
                assert 0
            if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-02, rtol=1.0e-02):
                print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
                assert 0
            if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-02, rtol=1.0e-02):
                print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
                assert 0


def test_mutual_information_deriv():
    print("Running test_mutual_information_deriv()")
    for _iter in range(100):
        (B, S, T) = (random.randint(1, 10),
                     random.randint(1, 200),
                     random.randint(1, 200))
        random_px = (random.random() < 0.2)
        random_py = (random.random() < 0.2)
        random_boundary = (random.random() < 0.7)
        big_px = (random.random() < 0.2)
        big_py = (random.random() < 0.2)
        print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, "
              f"random_py={random_py}, big_px={big_px}, big_py={big_py}, "
              f"random_boundary={random_boundary}")

        for dtype in [torch.float32, torch.float64]:
            # px_grads = []
            # py_grads = []
            # m_vals = []
            for device in [torch.device('cpu'), torch.device('cuda:0')]:
                print("dtype = ", dtype, ", device = ", device)
                if random_boundary:
                    def get_boundary_row():
                        s_begin = random.randint(0, S - 1)
                        t_begin = random.randint(0, T - 1)
                        s_end = random.randint(s_begin + 1, S)
                        t_end = random.randint(t_begin + 1, T)
                        return [s_begin, t_begin, s_end, t_end]

                    if device == torch.device('cpu'):
                        boundary = torch.tensor(
                            [get_boundary_row() for _ in range(B)],
                            dtype=torch.int64, device=device)
                    else:
                        boundary = boundary.to(device)
                else:
                    # Use default boundary, but either specified directly or not.
                    if random.random() < 0.5:
                        boundary = torch.tensor(
                            [0, 0, S, T],
                            dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
                    else:
                        boundary = None

                if device == torch.device('cpu'):
                    if random_px:
                        px = torch.randn(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                    else:
                        px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
                    # px and py get exponentiated, and then multiplied together up to
                    # 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big
                    # number that could lead to overflow.
                    if big_px:
                        px += 15.0
                    if random_py:
                        py = torch.randn(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                    else:
                        py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
                    if big_py:
                        py += 15.0
                else:
                    px = px.to(device).detach()
                    py = py.to(device).detach()
                px.requires_grad = True
                py.requires_grad = True

                m = mutual_information_recursion(px, py, boundary)
                # print("m = ", m)
                # print("exp(m) = ", m.exp())
                # print("px_grad = ", px.grad)
                # print("py_grad = ", py.grad)
                # px_grads.append(px.grad.to('cpu'))
                # py_grads.append(py.grad.to('cpu'))
                # m_vals.append(m.to('cpu'))

                m_grad = torch.randn(B, dtype=dtype, device=device)
                m.backward(gradient=m_grad)
                delta = 1.0e-04
                delta_px = delta * torch.randn_like(px)
                m2 = mutual_information_recursion(px + delta_px, py, boundary)
                delta_m = m2 - m
                observed_delta = (delta_m * m_grad).sum().to('cpu')
                predicted_delta = (delta_px * px.grad).sum().to('cpu')
                print(f"For px: observed,predicted objf changes are: "
                      f"{observed_delta}, {predicted_delta}, "
                      f"absolute objf was {(m * m_grad).sum()}")

                atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
                rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
                if not torch.allclose(observed_delta, predicted_delta,
                                      atol=atol, rtol=rtol):
                    print("Error: observed and predicted delta too different.")
                    assert 0

                delta_py = delta * torch.randn_like(py)
                m2 = mutual_information_recursion(px, py + delta_py, boundary)
                delta_m = m2 - m
                observed_delta = (delta_m * m_grad).sum().to('cpu')
                predicted_delta = (delta_py * py.grad).sum().to('cpu')
                print(f"For py: observed,predicted objf changes are: "
                      f"{observed_delta}, {predicted_delta}, "
                      f"absolute objf was {(m * m_grad).sum()}")

            # if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-02, rtol=1.0e-02):
            #     print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
            #     assert 0
            # if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-02, rtol=1.0e-02):
            #     print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
            #     assert 0
            # if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-02, rtol=1.0e-02):
            #     print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
            #     assert 0


if __name__ == "__main__":
    # torch.set_printoptions(edgeitems=30)
    test_mutual_information_basic()
    test_mutual_information_deriv()
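The derivative test above relies on a standard finite-difference identity: for a small random perturbation delta_px, the observed change (m(px + delta_px) - m(px)) dotted with m_grad should be close to the autograd prediction (delta_px * px.grad).sum(). Below is a small, self-contained sketch of the same kind of check using plain PyTorch autograd on an unrelated toy function; finite_difference_check is an illustrative helper name, not part of this repository.

import torch

def finite_difference_check(fn, x, delta=1.0e-04):
    # fn maps a tensor to a scalar; x is a leaf tensor with requires_grad=True.
    # Returns (observed, predicted) changes in the objective for a small
    # random perturbation of x; the two should agree closely if the
    # gradient is correct.
    y = fn(x)
    y.backward()
    dx = delta * torch.randn_like(x)
    observed = (fn((x + dx).detach()) - y).item()
    predicted = (dx * x.grad).sum().item()
    return observed, predicted

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
obs, pred = finite_difference_check(lambda t: torch.logsumexp(t, dim=0), x)
print(obs, pred)  # the two values should be nearly equal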
torch_mutual_information/rnnt.py
deleted
100644 → 0
import os
import torch
from torch import Tensor
from typing import Tuple, Optional
from .mutual_information import (mutual_information_recursion,
                                 joint_mutual_information_recursion)


def get_rnnt_logprobs(lm: Tensor, am: Tensor, symbols: Tensor,
                      termination_symbol: int) -> Tuple[Tensor, Tensor]:
    """
    Reduces the RNN-T problem (the simple case, where the joiner network is
    just addition) to a compact, standard form that can then be given (with
    boundaries) to mutual_information_recursion().  This function is called
    from rnnt_loss_simple(), but may be useful for other purposes.

    Args:
      lm: Language model part of un-normalized logprobs of symbols, to be
          added to the acoustic model part before normalizing.  Of shape:
             [B][S+1][C]
          where B is the batch size, S is the maximum sequence length of
          the symbol sequence, possibly including the EOS symbol; and
          C is the size of the symbol vocabulary, including the
          termination/next-frame symbol.
          Conceptually, lm[b][s] is a vector of length [C] representing the
          "language model" part of the un-normalized logprobs of symbols,
          given all symbols *earlier than* s in the sequence.  The reason
          we still need this for position S is that we may still be emitting
          the termination/next-frame symbol at this point.
      am: Acoustic-model part of un-normalized logprobs of symbols, to be
          added to the language-model part before normalizing.  Of shape:
             [B][T][C]
          where B is the batch size, T is the maximum sequence length of
          the acoustic sequences (in frames); and C is the size of the symbol
          vocabulary, including the termination/next-frame symbol.  It
          reflects the "acoustic" part of the probability of any given symbol
          appearing next on this frame.
      symbols: A LongTensor of shape [B][S], containing the symbols at each
          position of the sequence, possibly including EOS
      termination_symbol: The identity of the termination symbol, must be
          in {0..C-1}

    Returns: (px, py) (the names are quite arbitrary):
       px: logprobs, of shape [B][S][T+1]
       py: logprobs, of shape [B][S+1][T]
      in the recursion:
         p[b,0,0] = 0.0
         p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                            p[b,s,t-1] + py[b,s,t-1])
      .. where p[b][s][t] is the "joint score" of the pair of subsequences
      of length s and t respectively.  px[b][s][t] represents the probability
      of extending the subsequences of length (s,t) by one in the s direction,
      given the particular symbol, and py[b][s][t] represents the probability
      of extending the subsequences of length (s,t) by one in the t direction,
      i.e. of emitting the termination/next-frame symbol.

      px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
      we cannot emit any symbols.  This is simply a way of incorporating
      the probability of the termination symbol on the last frame.
    """
    assert lm.ndim == 3 and am.ndim == 3 and lm.shape[0] == am.shape[0] \
        and lm.shape[2] == am.shape[2]
    (B, T, C) = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S)

    # subtracting am_max and lm_max is to ensure the probs are in a good range
    # to do exp() without causing underflow or overflow.
    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
    am_probs = (am - am_max).exp()
    lm_probs = (lm - lm_max).exp()
    # normalizers: [B][S+1][T]
    normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2))
                   + 1.0e-20).log()
    # add lm_max and am_max to normalizers, to make it as if we had not
    # subtracted am_max and lm_max above.
    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]

    # px is the probs of the actual symbols..
    px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
                         index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)
                         ).squeeze(-1)  # [B][S][T]
    px_am = torch.cat((px_am,
                       torch.full((B, S, 1), float('-inf'),
                                  device=px_am.device, dtype=px_am.dtype)),
                      dim=2)  # now: [B][S][T+1], index [:,:,T] has -inf..
    px_lm = torch.gather(lm[:, :S], dim=2,
                         index=symbols.unsqueeze(-1))  # [B][S][1]

    px = px_am + px_lm  # [B][S][T+1], last slice indexed [:,:,T] is -inf
    px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1]

    # py is the probs of termination symbols, of shape [B][S+1][T]
    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
    py = py_am + py_lm - normalizers

    return (px, py)


def rnnt_loss_simple(lm: Tensor, am: Tensor, symbols: Tensor,
                     termination_symbol: int,
                     boundary: Optional[Tensor] = None) -> Tensor:
    """
    A simple case of the RNN-T loss, where the 'joiner' network is just
    addition.  Returns the negated total loss value.

    Args:
     lm: language-model part of unnormalized log-probs of symbols, with shape
        (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes
     am: acoustic-model part of unnormalized log-probs of symbols, with shape
        (B, T, C), i.e. batch, frame, num_classes
     symbols: the symbol sequences, a LongTensor of shape [B][S], with
        elements in {0..C-1}.
     termination_symbol: the termination symbol, with
        0 <= termination_symbol < C
     boundary: a LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
        [0, 0, S, T] if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
    Returns:
       a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
       for each element of the batch (like log-probs of sequences).
    """
    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
    return mutual_information_recursion(px, py, boundary)
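As a minimal usage sketch of rnnt_loss_simple as defined above (not part of the file itself): the shapes follow the docstring, and the random inputs and sizes are illustrative only.

import torch

B, S, T, C = 2, 5, 8, 10
lm = torch.randn(B, S + 1, C)           # language-model logits: [B][S+1][C]
am = torch.randn(B, T, C)               # acoustic-model logits: [B][T][C]
symbols = torch.randint(1, C, (B, S))   # reference symbols: [B][S]
termination_symbol = 0

# Negated total RNN-T loss per batch element, shape (B,); boundary=None means
# every sequence uses the full range [0, 0, S, T].
loss = rnnt_loss_simple(lm, am, symbols, termination_symbol, boundary=None)
print(loss.shape)  # torch.Size([2])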

def get_rnnt_logprobs_aux(lm: Tensor, am: Tensor, symbols: Tensor,
                          termination_symbol: int,
                          lm_only_scale: float = 0.1,
                          am_only_scale: float = 0.1) -> Tuple[Tensor, Tensor]:
    """
    Reduces the RNN-T problem (the simple case, where the joiner network is
    just addition) to a compact, standard form that can then be given (with
    boundaries) to mutual_information_recursion().  This version allows you
    to make the loss-function one of the form:
          lm_only_scale * lm_probs +
          am_only_scale * am_probs +
          (1-lm_only_scale-am_only_scale) * combined_probs
    where lm_probs and am_probs are the probabilities given the lm and
    acoustic model independently.

    This function is called from rnnt_loss_aux(), but may be useful for other
    purposes.

    Args:
      lm: Language model part of un-normalized logprobs of symbols, to be
          added to the acoustic model part before normalizing.  Of shape:
             [B][S+1][C]
          where B is the batch size, S is the maximum sequence length of
          the symbol sequence, possibly including the EOS symbol; and
          C is the size of the symbol vocabulary, including the
          termination/next-frame symbol.
          Conceptually, lm[b][s] is a vector of length [C] representing the
          "language model" part of the un-normalized logprobs of symbols,
          given all symbols *earlier than* s in the sequence.  The reason
          we still need this for position S is that we may still be emitting
          the termination/next-frame symbol at this point.
      am: Acoustic-model part of un-normalized logprobs of symbols, to be
          added to the language-model part before normalizing.  Of shape:
             [B][T][C]
          where B is the batch size, T is the maximum sequence length of
          the acoustic sequences (in frames); and C is the size of the symbol
          vocabulary, including the termination/next-frame symbol.  It
          reflects the "acoustic" part of the probability of any given symbol
          appearing next on this frame.
      symbols: A LongTensor of shape [B][S], containing the symbols at each
          position of the sequence, possibly including EOS
      termination_symbol: The identity of the termination symbol, must be
          in {0..C-1}
      lm_only_scale: the scale on the "LM-only" part of the loss.
      am_only_scale: the scale on the "AM-only" part of the loss, for which
          we use an "averaged" LM (averaged over all histories, so
          effectively unigram).

    Returns: (px, py) (the names are quite arbitrary):
       px: logprobs, of shape [B][S][T+1]
       py: logprobs, of shape [B][S+1][T]
      in the recursion:
         p[b,0,0] = 0.0
         p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                            p[b,s,t-1] + py[b,s,t-1])
      .. where p[b][s][t] is the "joint score" of the pair of subsequences
      of length s and t respectively.  px[b][s][t] represents the probability
      of extending the subsequences of length (s,t) by one in the s direction,
      given the particular symbol, and py[b][s][t] represents the probability
      of extending the subsequences of length (s,t) by one in the t direction,
      i.e. of emitting the termination/next-frame symbol.

      px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
      we cannot emit any symbols.  This is simply a way of incorporating
      the probability of the termination symbol on the last frame.
    """
    assert lm.ndim == 3 and am.ndim == 3 and lm.shape[0] == am.shape[0] \
        and lm.shape[2] == am.shape[2]
    (B, T, C) = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S)

    # Caution: some parts of this code are a little less clear than they could
    # be due to optimizations.  In particular it may not be totally obvious
    # that all of the logprobs here are properly normalized.  We test that
    # this code is invariant to adding constants in the appropriate ways.

    # subtracting am_max and lm_max is to ensure the probs are in a good range
    # to do exp() without causing underflow or overflow.
    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
    am_probs = (am - am_max).exp()  # [B][T][C]
    lm_probs = (lm - lm_max).exp()  # [B][S+1][C]
    # normalizers: [B][S+1][T]
    normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2))
                   + 1.0e-20).log()

    # normalizer per frame, if we take only the LM probs by themselves
    lmonly_normalizers = lm_probs.sum(dim=2, keepdim=True)
    # lmonly_normalizers: [B][S+1][1]
    unigram_lm = torch.mean(lm_probs / lmonly_normalizers, dim=(0, 1),
                            keepdim=True) + 1.0e-20  # [1][1][C]
    amonly_normalizers = torch.mv(am_probs.reshape(-1, C),
                                  unigram_lm.reshape(C)
                                  ).reshape(B, T, 1).log() + am_max  # [B][T][1]
    amonly_normalizers = amonly_normalizers.transpose(1, 2)  # [B][1][T]
    unigram_lm = unigram_lm.log()
    lmonly_normalizers = lmonly_normalizers.log() + lm_max
    # [B][S+1][1], log-normalizer, used for LM-only part of prob.

    # add lm_max and am_max to normalizers, to make it as if we had not
    # subtracted am_max and lm_max above.
    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]

    # px is the probs of the actual symbols (not yet normalized)..
    px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
                         index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)
                         ).squeeze(-1)  # [B][S][T]
    px_am = torch.cat((px_am,
                       torch.full((B, S, 1), float('-inf'),
                                  device=px_am.device, dtype=px_am.dtype)),
                      dim=2)  # now: [B][S][T+1], index [:,:,T] has -inf..
    px_lm = torch.gather(lm[:, :S], dim=2,
                         index=symbols.unsqueeze(-1))  # [B][S][1]
    px_lm_unigram = torch.gather(unigram_lm.expand(B, S, C), dim=2,
                                 index=symbols.unsqueeze(-1))  # [B][S][1]

    px = px_am + px_lm  # [B][S][T+1], last slice indexed [:,:,T] is -inf
    px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1]

    px_amonly = px_am + px_lm_unigram  # [B][S][T+1]
    px_amonly[:, :, :T] -= amonly_normalizers
    px_lmonly = px_lm - lmonly_normalizers[:, :S, :]

    # py is the probs of termination symbols, of shape [B][S+1][T]
    py_am = am[:, :, termination_symbol].unsqueeze(1)  # [B][1][T]
    py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
    py = py_am + py_lm - normalizers

    py_lm_unigram = unigram_lm[0][0][termination_symbol]  # scalar, normalized..
    py_amonly = py_am + py_lm_unigram - amonly_normalizers  # [B][S+1][T]
    py_lmonly = py_lm - lmonly_normalizers  # [B][S+1][T]

    combined_scale = 1.0 - lm_only_scale - am_only_scale

    # We need to avoid exact zeros in the scales because otherwise multiplying
    # -inf by zero generates nan.
    if lm_only_scale == 0.0:
        lm_only_scale = 1.0e-20
    if am_only_scale == 0.0:
        am_only_scale = 1.0e-20

    px_interp = px * combined_scale + px_lmonly * lm_only_scale \
        + px_amonly * am_only_scale
    py_interp = py * combined_scale + py_lmonly * lm_only_scale \
        + py_amonly * am_only_scale
    print("px_interp = ", px_interp)
    print("py_interp = ", py_interp)
    return (px_interp, py_interp)


def rnnt_loss_aux(lm: Tensor, am: Tensor, symbols: Tensor,
                  termination_symbol: int,
                  lm_only_scale: float = 0.1,
                  am_only_scale: float = 0.1,
                  boundary: Optional[Tensor] = None) -> Tensor:
    """
    A simple case of the RNN-T loss, where the 'joiner' network is just
    addition.  Returns the negated total loss value.

    Args:
     lm: language-model part of unnormalized log-probs of symbols, with shape
        (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
        These are assumed to be well-normalized, in the sense that we could
        use them as probabilities separately from the am scores
     am: acoustic-model part of unnormalized log-probs of symbols, with shape
        (B, T, C), i.e. batch, frame, num_classes
     symbols: the symbol sequences, a LongTensor of shape [B][S], with
        elements in {0..C-1}.
     termination_symbol: the termination symbol, with
        0 <= termination_symbol < C
     lm_only_scale: the scale on the "LM-only" part of the loss.
     am_only_scale: the scale on the "AM-only" part of the loss, for which we
        use an "averaged" LM (averaged over all histories, so effectively
        unigram).
     boundary: a LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
        [0, 0, S, T] if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
    Returns:
       a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
       for each element of the batch (like log-probs of sequences).
    """
    px, py = get_rnnt_logprobs_aux(lm, am, symbols, termination_symbol,
                                   lm_only_scale, am_only_scale)
    return mutual_information_recursion(px, py, boundary)
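A brief usage sketch of rnnt_loss_aux as defined above (again, not part of the file): setting both auxiliary scales to zero should reduce it to rnnt_loss_simple, which is the equivalence rnnt_test.py below relies on. Sizes and inputs are illustrative only.

import torch

B, S, T, C = 2, 5, 8, 10
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(1, C, (B, S))
termination_symbol = 0

# Interpolated loss: 0.333 of the weight goes to the "AM-only" (unigram-LM) term.
loss_aux = rnnt_loss_aux(lm, am, symbols, termination_symbol,
                         lm_only_scale=0.0, am_only_scale=0.333,
                         boundary=None)

# With both auxiliary scales at zero, the result should closely match
# rnnt_loss_simple on the same inputs.
loss_plain = rnnt_loss_aux(lm, am, symbols, termination_symbol,
                           lm_only_scale=0.0, am_only_scale=0.0,
                           boundary=None)
print(loss_aux, loss_plain,
      rnnt_loss_simple(lm, am, symbols, termination_symbol, None))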
torch_mutual_information/rnnt_test.py
deleted
100644 → 0
import random
import torch
from torch_mutual_information import (mutual_information_recursion,
                                      joint_mutual_information_recursion,
                                      get_rnnt_logprobs, rnnt_loss_simple,
                                      rnnt_loss_aux)


def test_rnnt_logprobs_basic():
    print("Running test_rnnt_logprobs_basic()")
    B = 1
    S = 3
    T = 4
    C = 3
    # lm: [B][S+1][C]
    lm = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
                      dtype=torch.float)
    # am: [B][T][C]
    am = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
                      dtype=torch.float)
    # lm[:] = 0.0
    # am[:] = 0.0
    termination_symbol = 2
    symbols = torch.tensor([[0, 1, 0]], dtype=torch.long)

    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
    assert px.shape == (B, S, T + 1)
    assert py.shape == (B, S + 1, T)
    assert symbols.shape == (B, S)
    print("px = ", px)
    print("py = ", py)
    m = mutual_information_recursion(px, py)
    print("m = ", m)

    # should be invariant to adding a constant for any frame.
    lm += torch.randn(B, S + 1, 1)
    am += torch.randn(B, T, 1)
    m2 = rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
    print("m2 = ", m2)

    device = torch.device('cuda')
    m3 = rnnt_loss_simple(lm.to(device), am.to(device), symbols.to(device),
                          termination_symbol, None)
    print("m3 = ", m3)

    device = torch.device('cuda')
    m4 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
                       termination_symbol, lm_only_scale=0.0,
                       am_only_scale=0.0, boundary=None)
    print("m4 = ", m4)

    assert torch.allclose(m, m2)
    assert torch.allclose(m, m3.to('cpu'))
    assert torch.allclose(m, m4.to('cpu'))


def test_rnnt_logprobs_aux():
    print("Running test_rnnt_logprobs_aux()")
    B = 1
    S = 3
    T = 4
    C = 3
    # lm: [B][S+1][C]
    lm = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
                      dtype=torch.float)
    # am: [B][T][C]
    am = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
                      dtype=torch.float)
    termination_symbol = 2
    symbols = torch.tensor([[0, 1, 0]], dtype=torch.long)

    device = torch.device('cuda')
    m1 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
                       termination_symbol, lm_only_scale=0.0,
                       am_only_scale=0.333, boundary=None)
    print("m1 = ", m1)

    # should be invariant to adding a constant for any frame.
    lm += torch.randn(B, S + 1, 1)
    am += torch.randn(B, T, 1)
    m2 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device),
                       termination_symbol, lm_only_scale=0.0,
                       am_only_scale=0.333, boundary=None)
    print("m2 = ", m2)
    assert torch.allclose(m1, m2)


if __name__ == "__main__":
    # torch.set_printoptions(edgeitems=30)
    test_rnnt_logprobs_aux()
    test_rnnt_logprobs_basic()