Unverified Commit a283cddf authored by Daniel Povey, committed by GitHub

Merge pull request #4 from pkufool/fast_rnnt

Sync with k2 rnnt_loss
parents b5828e2b 182fe8de
include_directories(${CMAKE_SOURCE_DIR})
# it is located in fast_rnnt/cmake/transform.cmake
include(transform)
set(srcs
mutual_information_cpu.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
transform(OUTPUT_VARIABLE srcs SRCS ${srcs})
else()
list(APPEND srcs mutual_information_cuda.cu)
endif()
add_library(mutual_information_core ${srcs})
target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
# for <torch/extension.h>
target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# see https://github.com/NVIDIA/thrust/issues/1401
# and https://github.com/k2-fsa/k2/pull/917
target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
target_link_libraries(mutual_information_core PUBLIC cub)
endif()
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang, Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_DEVICE_GUARD_H_
#define FAST_RNNT_CSRC_DEVICE_GUARD_H_
#include "torch/script.h"
// This file is modified from
// https://github.com/k2-fsa/k2/blob/master/k2/csrc/device_guard.h
namespace fast_rnnt {
// DeviceGuard is an RAII class. Its sole purpose is to restore
// the previous default cuda device if a CUDA context changes the
// current default cuda device.
class DeviceGuard {
public:
explicit DeviceGuard(torch::Device device) {
if (device.type() == torch::kCUDA) {
old_device_ = GetDevice();
new_device_ = device.index();
if (old_device_ != new_device_)
SetDevice(new_device_);
}
// else do nothing
}
explicit DeviceGuard(int32_t new_device) : new_device_(new_device) {
if (new_device != -1) {
old_device_ = GetDevice();
if (old_device_ != new_device)
SetDevice(new_device);
}
}
~DeviceGuard() {
if (old_device_ != -1 && old_device_ != new_device_) {
// restore the previous device
SetDevice(old_device_);
}
// else it was either a CPU context or the device IDs
// were the same
}
DeviceGuard(const DeviceGuard &) = delete;
DeviceGuard &operator=(const DeviceGuard &) = delete;
DeviceGuard(DeviceGuard &&) = delete;
DeviceGuard &operator=(DeviceGuard &&) = delete;
private:
static int32_t GetDevice() {
#ifdef FT_WITH_CUDA
int32_t device;
auto s = cudaGetDevice(&device);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
return device;
#else
return -1;
#endif
}
static void SetDevice(int32_t device) {
#ifdef FT_WITH_CUDA
auto s = cudaSetDevice(device);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
return;
#endif
}
private:
int32_t old_device_ = -1;
int32_t new_device_ = -1;
};
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_DEVICE_GUARD_H_
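As a usage illustration (a hypothetical caller, not part of the header above): the guard switches to a tensor's CUDA device for the duration of a scope and restores the previous device when it is destroyed.
// Sketch only; `LaunchOnDeviceOf` is a made-up helper name.
void LaunchOnDeviceOf(torch::Tensor px) {
  fast_rnnt::DeviceGuard guard(px.device());
  // ... launch kernels or call torch ops that assume the current CUDA device ...
}  // `guard` restores the previous default device here.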
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#include <cmath>
#include <vector>
#include "torch/extension.h"
#ifdef __CUDA_ARCH__
#define FT_CUDA_HOSTDEV __host__ __device__
#else
#define FT_CUDA_HOSTDEV
#endif
namespace fast_rnnt {
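// returns log(exp(x) + exp(y)); e.g. LogAdd(log(0.25), log(0.75)) == log(1.0) == 0 (up to rounding).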
FT_CUDA_HOSTDEV inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
// returns log(exp(x) + exp(y)).
FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
/*
Forward of mutual_information. See also comment of `mutual_information`
in ../python/fast_rnnt/mutual_information.py. This is the core recursion
in the sequence-to-sequence mutual information computation.
@param px Tensor of shape [B][S][T + 1] if not modified, [B][S][T] if
modified; `modified` can be worked out from this. In the not-modified case,
it can be thought of as the log-odds ratio of generating the next x in
the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of
lengths (s, t), divided by the prior probability of generating x_s.
(See mutual_information.py for more info).
@param py The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
@param p This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
if not modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
... treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
@param boundary If set, a tensor of shape [B][4] of type int64_t; for
each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end],
which are the beginning and end (i.e. one-past-the-last)
of the x and y sequences that we should process.
Alternatively, may be a tensor of shape [0][0] and type
int64_t; the elements will default to (0, 0, S, T).
@return A tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
torch::Tensor MutualInformationCpu(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
torch::Tensor MutualInformationCuda(
torch::Tensor px, // [B][S][T+1] if !modified, [B][S][T] if modified.
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
/*
backward of mutual_information; returns (grad_px, grad_py)
if overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad);
std::vector<torch::Tensor> MutualInformationBackwardCuda(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
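To make the recursion documented above concrete, here is a minimal CPU-only sketch of the not-modified case with default boundaries (toy sizes and values, plain arrays instead of torch::Tensor; not part of this header):
#include <cmath>
#include <cstdio>
int main() {
  const int S = 2, T = 2;
  // px has shape [S][T + 1], py has shape [S + 1][T]; the values are arbitrary.
  double px[S][T + 1] = {{-0.5, -0.7, -0.9}, {-0.4, -0.6, -0.8}};
  double py[S + 1][T] = {{-1.0, -1.1}, {-0.9, -1.2}, {-0.8, -1.3}};
  double p[S + 1][T + 1];
  p[0][0] = 0.0;
  for (int s = 0; s <= S; ++s) {
    for (int t = 0; t <= T; ++t) {
      if (s == 0 && t == 0) continue;  // p[0][0] stays 0.
      // Treat any -1 index as -infinity, as in the recursion above.
      double term_x = (s > 0 ? p[s - 1][t] + px[s - 1][t] : -INFINITY);
      double term_y = (t > 0 ? p[s][t - 1] + py[s][t - 1] : -INFINITY);
      // log_add(a, b), written the same way as LogAdd() in this header.
      double mx = fmax(term_x, term_y), mn = fmin(term_x, term_y);
      p[s][t] = (mn - mn != 0 ? mx : mx + log1p(exp(mn - mx)));
    }
  }
  printf("ans = p[S][T] = %g\n", p[S][T]);
  return 0;
}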
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "fast_rnnt/csrc/mutual_information.h"
namespace fast_rnnt {
// forward of mutual_information. See """... """ comment of
// `mutual_information_recursion`
// in k2/python/k2/mutual_information.py for documentation of the
// behavior of this function.
// px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
// `modified` from this.
// py: of shape [B, S+1, T]
// boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end)
// defaulting to (0, 0, S, T).
// p: of shape (B, S+1, T+1)
// Computes the recursion:
// if !modified:
// p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
// p[b,s,t-1] + py[b,s,t-1])
// if modified:
// p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
// p[b,s,t-1] + py[b,s,t-1])
// .. treating out-of-range elements as -infinity and with special cases:
// p[b, s_begin, t_begin] = 0.0
//
// and this function returns a tensor of shape (B,) consisting of elements
// p[b, s_end, t_end]
torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu(),
"inputs must be CPU tensors");
bool modified = (px.size(2) == py.size(2));
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(),
py_a = py.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>();
auto boundary_a = boundary.accessor<int64_t, 2>();
auto ans_a = ans.accessor<scalar_t, 1>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
p_a[b][s_begin][t_begin] = 0.0;
if (modified) {
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] = -std::numeric_limits<scalar_t>::infinity();
} else {
// note: t_offset = 0 so don't need t_begin + t_offset below.
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] =
p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
}
for (int t = t_begin + 1; t <= t_end; ++t)
p_a[b][s_begin][t] =
p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
for (int s = s_begin + 1; s <= s_end; ++s) {
scalar_t p_s_t1 = p_a[b][s][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t) {
// The following statement is a small optimization of:
// p_a[b][s][t] = LogAdd(
// p_a[b][s - 1][t + t_offset] + px_a[b][s -1][t + t_offset],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
p_s_t1 + py_a[b][s][t - 1]);
}
}
ans_a[b] = p_a[b][s_end][t_end];
}
}));
return ans;
}
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p, torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
bool modified = (px.size(2) == py.size(2));
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu() && ans_grad.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
bool has_boundary = opt_boundary.has_value();
int T1 = T + (modified ? 0 : 1);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts)
: torch::empty({B, S, T1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
: torch::empty({B, S + 1, T}, opts));
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>(),
p_grad_a = p_grad.accessor<scalar_t, 3>(),
px_grad_a = px_grad.accessor<scalar_t, 3>(),
py_grad_a = py_grad.accessor<scalar_t, 3>();
auto ans_grad_a = ans_grad.accessor<scalar_t, 1>();
auto boundary_a = boundary.accessor<int64_t, 2>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
// Backprop for: ans_a[b] = p_a[b][s_end][t_end];
p_grad_a[b][s_end][t_end] = ans_grad_a[b];
for (int s = s_end; s > s_begin; --s) {
for (int t = t_end; t > t_begin; --t) {
// The s,t indexes correspond to
// The statement we are backpropagating here is:
// p_a[b][s][t] = LogAdd(
// p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
scalar_t term1 = p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t];
if (total - total != 0)
total = 0;
scalar_t term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t];
scalar_t term1_grad, term2_grad;
if (term1_deriv - term1_deriv == 0.0) {
term1_grad = term1_deriv * grad;
term2_grad = term2_deriv * grad;
} else {
// could happen if total == -inf
term1_grad = term2_grad = 0.0;
}
px_grad_a[b][s - 1][t + t_offset] = term1_grad;
p_grad_a[b][s - 1][t + t_offset] = term1_grad;
py_grad_a[b][s][t - 1] = term2_grad;
p_grad_a[b][s][t - 1] += term2_grad;
}
}
for (int t = t_end; t > t_begin; --t) {
// Backprop for:
// p_a[b][s_begin][t] =
// p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
scalar_t this_p_grad = p_grad_a[b][s_begin][t];
p_grad_a[b][s_begin][t - 1] += this_p_grad;
py_grad_a[b][s_begin][t - 1] = this_p_grad;
}
if (!modified) {
for (int s = s_end; s > s_begin; --s) {
// Backprop for:
// p_a[b][s][t_begin] =
// p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
p_grad_a[b][s - 1][t_begin] += this_p_grad;
px_grad_a[b][s - 1][t_begin] = this_p_grad;
}
} // else these were all -infinity's and there is nothing to
// backprop.
// There is no backprop for:
// p_a[b][s_begin][t_begin] = 0.0;
// .. but we can use this for a check, that the grad at the beginning
// of the sequence is equal to the grad at the end of the sequence.
if (ans_grad_a[b] != 0.0) {
float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
if (fabs(grad_ratio - 1.0) > 0.01) {
std::cout
<< "Warning: mutual_information backprop: expected these "
<< "numbers to be the same:"
<< static_cast<float>(p_grad_a[b][s_begin][t_begin]) << " vs "
<< static_cast<float>(ans_grad_a[b]);
}
}
}
}));
return std::vector<torch::Tensor>({px_grad, py_grad});
}
} // namespace fast_rnnt
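For orientation, a hypothetical caller of the two CPU entry points above might look like the following (the function name and tensor sizes are made up; default boundaries are requested by passing an empty optional):
// Sketch only: run the CPU forward and backward once and return the answer.
torch::Tensor ToyMutualInformation() {
  const int B = 2, S = 3, T = 4;
  torch::Tensor px = torch::randn({B, S, T + 1});      // not-modified layout
  torch::Tensor py = torch::randn({B, S + 1, T});
  torch::Tensor p = torch::empty({B, S + 1, T + 1});   // filled by the forward
  torch::Tensor ans = fast_rnnt::MutualInformationCpu(
      px, py, torch::optional<torch::Tensor>(), p);
  // Backprop with d(loss)/d(ans[b]) = 1 for every sequence in the batch;
  // grads[0] is px_grad and grads[1] is py_grad.
  std::vector<torch::Tensor> grads = fast_rnnt::MutualInformationBackwardCpu(
      px, py, torch::optional<torch::Tensor>(), p, torch::ones({B}));
  return ans;
}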
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <c10/cuda/CUDAStream.h> // for getCurrentCUDAStream()
#include <cooperative_groups.h>
#include "fast_rnnt/csrc/mutual_information.h"
namespace fast_rnnt {
/*
Forward of mutual_information. Each thread block computes blocks of the 'p'
array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
@@ -55,13 +40,14 @@ __forceinline__ __device__ float LogAdd(float x, float y) {
is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading
code).
Args:
px: Tensor of shape [B][S][T + 1], if !modified; [B][S][T] if modified;
may be interpreted as the log-odds ratio of
generating the next x in the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of lengths
(s, t), divided by the prior probability of generating x_s. (See
mutual_information.py for more info).
py: The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
p: This function writes to p[b][s][t] the mutual information between
@@ -71,10 +57,14 @@ __forceinline__ __device__ float LogAdd(float x, float y) {
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
if not `modified`:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1]) (eq. 0)
if `modified`:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1]) (eq. 0)
treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
boundary: If set, a tensor of shape [B][4] of type int64_t; for each
batch element b, boundary[b] equals
@@ -95,29 +85,32 @@ __forceinline__ __device__ float LogAdd(float x, float y) {
be at least 128.
*/
template <typename scalar_t,
int BLOCK_SIZE> // e.g. BLOCK_SIZE == 16 or 32.
__global__ void mutual_information_kernel(
// B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
torch::PackedTensorAccessor32<scalar_t, 3> px,
torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T.
// B, S + 1, T + 1. This is an output.
torch::PackedTensorAccessor32<scalar_t, 3> p,
// B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
torch::PackedTensorAccessor32<int64_t, 2> boundary,
torch::PackedTensorAccessor32<scalar_t, 1> ans, // [B]
int iter) { // This kernel is sequentially called with 'iter' = 0, 1, 2 and
// so on, up to num_iters - 1, where
// num_iters = num_s_blocks + num_t_blocks - 1,
// num_s_blocks = S / BLOCK_SIZE + 1,
// num_t_blocks = T / BLOCK_SIZE + 1,
// so that each group depends on the previous group...
const int B = px.size(0), S = px.size(1), T = py.size(2);
const bool modified = (px.size(2) == T);
const int t_offset = (modified ? -1 : 0); // see CPU code to understand.
// num_s_blocks and num_t_blocks are the number of blocks we need to cover the
// array of size (S, T) with blocks of this size, in the s and t directions
// respectively.
// You can read the following expressions as simplifications of, for example,
// num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
// i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 1).
const int num_s_blocks = S / BLOCK_SIZE + 1;
//, num_t_blocks = T / BLOCK_SIZE + 1;
@@ -134,16 +127,16 @@ void mutual_information_kernel(
int num_blocks_this_iter = min(iter + 1, num_s_blocks);
// For the block with s_block_begin == 0 and t_block_begin == 0 (for
// easy illustration), px_buf[s][t] will contain px[s - 1][t + t_offset], or
// -infinity for out-of-range indexes into px. Likewise, py_buf[s][t] will
// contain py[s][t - 1].
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
py_buf[BLOCK_SIZE][BLOCK_SIZE];
// p_buf[s][t] == p[s+s_block_begin-1][t+t_block_begin-1]
// 1st row/col of p_buf correspond to the previously computed blocks (lower
// `iter`), or to negative indexes into p. So, for the origin block,
// p_buf[s][t] corresponds to p[s - 1][t - 1]; or -inf for
// out-of-range values.
__shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];
@@ -165,7 +158,7 @@ void mutual_information_kernel(
batch_block_iter < B * num_blocks_this_iter;
batch_block_iter += gridDim.x) {
int block = batch_block_iter / B,
b = batch_block_iter % B; // b is the index into the batch
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and
@@ -176,15 +169,13 @@ void mutual_information_kernel(
__syncthreads();
if (threadIdx.x < 4)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
int s_begin = boundary_buf[0], t_begin = boundary_buf[1],
s_end = boundary_buf[2], t_end = boundary_buf[3];
s_block_begin += s_begin;
t_block_begin += t_begin;
@@ -200,95 +191,61 @@ void mutual_information_kernel(
if (block_S <= 0 || block_T <= 0)
continue;
// Load px_buf and py_buf.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin, t = t_in_block + t_block_begin,
t_off = t + t_offset;
// comparing as unsigned int makes sure the index is nonnegative.
// Caution: if s_begin > 0 or t_begin > 0 we may end up loading some px
// and py values that are outside the proper boundaries that we need, but
// the corresponding p_buf values will end up being 0 so this won't
// matter.
scalar_t this_px = -INFINITY;
// Below, "&& t <= t_end" can be interpreted as:
// "&& (modified ? t_off < t_end : t_off <= t_end)"
// [since px's last valid index is t_end - 1 if modified, else t_end].
if (s > s_begin && s <= s_end && t_off >= t_begin && t <= t_end)
this_px = px[b][s - 1][t_off];
px_buf[s_in_block][t_in_block] = this_px;
scalar_t this_py = -INFINITY;
if (t > t_begin && t <= t_end && s <= s_end)
this_py = py[b][s][t - 1];
py_buf[s_in_block][t_in_block] = this_py;
}
// Load the 1st row and 1st column of p_buf.
// This is the context from previously computed blocks of the
// image. Remember: p_buf[s][t] will correspond to p[s + s_block_begin -
// 1][t + t_block_begin - 1]
if (threadIdx.x <= BLOCK_SIZE) {
// s_in_p_buf and t_in_p_buf are simply the indexes into p_buf
int s_in_p_buf = threadIdx.x, t_in_p_buf = 0,
s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1;
scalar_t this_p = -INFINITY;
if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
this_p = p[b][s][t];
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
} else if (static_cast<unsigned int>(static_cast<int>(threadIdx.x) - 64) <=
static_cast<unsigned int>(BLOCK_SIZE)) {
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
int s_in_p_buf = 0, t_in_p_buf = static_cast<int>(threadIdx.x) - 64,
s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1;
scalar_t this_p = -INFINITY;
if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
this_p = p[b][s][t];
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
}
__syncthreads();
// from here to the next __syncthreads(), only the 1st warp should be active
// so we shouldn't need to synchronize. (implicit within-warp
// synchronization).
@@ -299,19 +256,12 @@ void mutual_information_kernel(
// to set p_buf to 0.0 = log(1.0) if this is the "origin block",
// i.e. s == s_begin, t == t_begin. This corresponds to the
// probability of the pair of sequences of length (0, 0).
p_buf[1][1] =
(is_origin_block ? 0.0
: LogAdd(
// px_buf has t_offset applied.
p_buf[0][1 + t_offset] + px_buf[0][0],
p_buf[1][0] + py_buf[0][0]));
}
int s = threadIdx.x;
@@ -333,34 +283,23 @@ void mutual_information_kernel(
static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, the code
// below can be interpreted as follows,
// writing sbb for s_block_begin and tbb for t_block_begin:
//
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] +
// px[s+sbb-1][t+tbb],
// p[b][s+sbb][t+tbb-1] +
// py[s+sbb][t+tbb-1])
//
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(); and (eq. 0)
// above.
// note: px_buf has t_offset applied..
p_buf[s + 1][t + 1] = LogAdd(p_buf[s][t + 1 + t_offset] + px_buf[s][t],
p_buf[s + 1][t] + py_buf[s][t]);
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
@@ -368,21 +307,13 @@ void mutual_information_kernel(
}
__syncthreads();
// Write out the data to p.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
if (s_in_block < block_S && t_in_block < block_T) {
scalar_t this_p = p_buf[s_in_block + 1][t_in_block + 1];
p[b][s][t] = this_p;
}
}
@@ -397,165 +328,112 @@ void mutual_information_kernel(
// you could read block_S below as block_S - 1 + 1, meaning,
// it's the last index in a block of size block_S, but the indexes into
// p_buf have a "+ 1". Likewise for block_T.
ans[b] = p_buf[block_S][block_T];
}
}
}
}
// like exp(), but returns 0 if arg is inf/nan, or if the result would be
// infinity or nan (note: this can happen for out-of-range elements
// when setting px_buf and py_buf if block_S != BLOCK_SIZE or
// block_T != BLOCK_SIZE); it's a problem because even though
// out-of-range gradients are zero, if we multiply them by infinity
// we get NaN.
template <typename Real> __forceinline__ __device__ Real safe_exp(Real x) {
if (x - x != 0)
return 0;
else {
Real ans = exp(x);
if (ans - ans != 0.0)
return 0;
return ans;
}
}
/*
Backward of mutual_information.
The forward pass is:
p[b,s,t] = log_add(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset],
p[b,s,t-1] + py[b,s,t-1]) (eq. 0)
where t_offset = (modified ? -1 : 0)
The backprop for the above, implemented in the obvious way, would be as
follows (note, we define term1 and term2 with offsets in the indexes, which
will be convenient later):
term1(b,s-1,t+t_offset) =
exp(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset] - p[b,s,t]) (0a)
term2(b,s,t-1) = exp(p[b,s,t-1] + py[b,s,t-1] - p[b,s,t]) (0b)
p_grad[b,s-1,t+t_offset] += p_grad[b,s,t] * term1(b,s-1,t+t_offset) (1a)
px_grad[b,s-1,t+t_offset] += p_grad[b,s,t] * term1(b,s-1,t+t_offset) (1b)
p_grad[b,s,t-1] += p_grad[b,s,t] * term2(b,s,t-1) (1c)
py_grad[b,s,t-1] += p_grad[b,s,t] * term2(b,s,t-1) (1d)
Adding 1 and -t_offset to the s and t indexes of (1a) and (1b), and
1 to the t index of (1c) and (1d), the equations become:
p_grad[b,s,t] += p_grad[b,s+1,t-t_offset] * term1(b,s,t) (2a)
px_grad[b,s,t] += p_grad[b,s+1,t-t_offset] * term1(b,s,t) (2b)
p_grad[b,s,t] += p_grad[b,s,t+1] * term2(b,s,t) (2c)
py_grad[b,s,t] += p_grad[b,s,t+1] * term2(b,s,t) (2d)
.. and replacing "+=" with "=", we can write:
p_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t) + (3a)
p_grad[b,s,t+1] * term2(b,s,t)
px_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t) (3b)
py_grad[b,s,t] = p_grad[b,s,t+1] * term2(b,s,t) (3c)
Writing the definitions of term1 and term2 in a more convenient way:
term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset]) (4a)
term2(b,s,t) = exp(p[b,s,t] + py[b,s,t] - p[b,s,t+1]) (4b)
The backward pass will be slightly different from the forward pass in terms of
how we store and index p (and p_grad), because for writing a particular block
of p_grad, we need context on the top and right instead of the bottom and
left. So there are offsets of 1.
*/
template <typename scalar_t, int BLOCK_SIZE>
__global__ void mutual_information_backward_kernel(
torch::PackedTensorAccessor32<scalar_t, 3>
px, // B, S, T + 1 if !modified; B, S, T if modified.
torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T.
// B, S + 1, T + 1. Produced in forward pass.
torch::PackedTensorAccessor32<scalar_t, 3> p,
// [B]. This is an input.
torch::PackedTensorAccessor32<scalar_t, 1> ans_grad,
torch::PackedTensorAccessor32<scalar_t, 3>
p_grad, // B, S + 1, T + 1. This is a temporary.
// B, S, T + 1 if !modified; B, S, T if modified.
torch::PackedTensorAccessor32<scalar_t, 3> px_grad,
torch::PackedTensorAccessor32<scalar_t, 3> py_grad, // B, S + 1, T.
// B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
torch::PackedTensorAccessor32<int64_t, 2> boundary,
int iter, // This kernel is sequentially called with 'iter' = num_iters
// - 1, num_iters - 2, .. 0, where num_iters can be taken to
// be any sufficiently large number but will actually be:
// num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
// BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
bool overwrite_ans_grad) { // If overwrite_ans_grad == true, this function
// will overwrite ans_grad with a value which,
// if everything is working correctly, should be
// identical or very close to the value of
// ans_grad that was passed in.
const int B = px.size(0), S = px.size(1), T = py.size(2);
const bool modified = (px.size(2) == T);
const int neg_t_offset = (modified ? 1 : 0);
// For statements that are the same as the forward pass, we are omitting some
// comments. We'll focus, in the comments, on differences from the forward
// pass.
const int num_s_blocks = S / BLOCK_SIZE + 1,
// num_t_blocks = T / BLOCK_SIZE + 1,
num_blocks_this_iter = min(iter + 1, num_s_blocks);
// px_buf and py_buf are used temporarily to store the px and py values,
// but then modified to store the "term1" and "term2" values defined
// in (4a) and (4b) above. For out-of-range values, we'll write 0.0
@@ -564,15 +442,17 @@ void mutual_information_backward_kernel(
// px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
// py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
// Later (see eq. 4a and eq. 4b):
// px_buf[s][t] contains term1(b,ss,tt) ==
// exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt-t_offset]),
// py_buf[s][t] contains term2(b,ss,tt) ==
// exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1]),
// where ss == s + s_block_begin, tt = t + t_block_begin.
// Unlike in the forward code, there is no offset of 1 in the indexes.
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
py_buf[BLOCK_SIZE][BLOCK_SIZE];
// p_buf is initially used to store p, and then (after we are done putting
// term1 and term2 into px_buf and py_buf) it is repurposed to store
// p_grad.
//
// Unlike in the forward pass, p_buf has the same numbering as px_buf and
@@ -603,19 +483,16 @@ void mutual_information_backward_kernel(
for (int batch_block_iter = blockIdx.x;
batch_block_iter < B * num_blocks_this_iter;
batch_block_iter += gridDim.x) {
int block = batch_block_iter / B, b = batch_block_iter % B;
int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE;
if (threadIdx.x < 4)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
int s_begin = boundary_buf[0], t_begin = boundary_buf[1],
s_end = boundary_buf[2], t_end = boundary_buf[3];
s_block_begin += s_begin;
t_block_begin += t_begin;
@@ -633,13 +510,11 @@ void mutual_information_backward_kernel(
// Load px_buf and py_buf. At this point we just set them to the px and py
// for this block.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
// We let px and py default to -infinity if they are out of range, which
// will cause term1 and term2 for out-of-range values to be zero, and
// cause correct behavior in edge cases (for the top and right blocks).
// The issue is that p and p_grad are of larger size than px and py.
scalar_t this_px = -INFINITY;
if (s < s_end && t <= t_end)
@@ -653,11 +528,10 @@ void mutual_information_backward_kernel(
__syncthreads();
// load p.
for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1);
i += blockDim.x) {
int s_in_block = i / (BLOCK_SIZE + 1), t_in_block = i % (BLOCK_SIZE + 1),
s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
// Setting 0.0 for out-of-bounds elements of p, together with setting
// -INFINITY for out-of-bounds elements of px_buf and py_buf, will
// ensure that we do the right thing in top and right edge cases,
@@ -666,56 +540,57 @@ void mutual_information_backward_kernel(
scalar_t this_p = 0.0;
if (s <= s_end && t <= t_end)
this_p = p[b][s][t];
p_buf[s_in_block][t_in_block] = this_p;
}
__syncthreads();
// Set term1 and term2; see equations (4a) and (4b) above.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
// We can apply this formula to the entire block even if we are processing
// a partial block; we have ensured that px_buf and py_buf contain
// -infinity, and p contains 0, for out-of-range elements, so we'll get
// px_buf and py_buf containing 0 after applying the following formulas.
int s = i / BLOCK_SIZE, t = i % BLOCK_SIZE;
// Mathematically the following is doing:
// term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset]) (4a)
// (with an offset on the s and t indexes)
// Use safe_exp() not exp(), as we could have (-inf) - (-inf) = nan; we want
// any finite number in this case, as the derivs would be zero.
// Also want -inf -> zero.
px_buf[s][t] =
safe_exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t + neg_t_offset]);
// Mathematically the following is doing:
// term2(b,s,t) = exp(p[b,s,t] + py[b,s,t] - p[b,s,t+1]) (4b)
// (with an offset on the s and t indexes)
py_buf[s][t] = safe_exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
}
__syncthreads();
// Load p_grad for the top and right elements in p_buf: i.e. for elements
// p_buf[s][t] where s == block_S (exclusive-or) t == block_T.
// These are the p_grad values computed by previous instances of this kernel.
// If this is one of the top or right blocks, some or all of the p_grad
// values we'd be reading here will be out of range, and we use zeros
// to ensure no gradient gets propagated from those positions.
if (threadIdx.x <= block_S) {
int s_in_block = threadIdx.x, t_in_block = block_T,
s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
s = s_in_block + s_block_begin, p_buf[s_in_block][t_in_block] =
t = t_in_block + t_block_begin; (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
p_buf[s_in_block][t_in_block] = ( } else if (static_cast<unsigned int>(static_cast<int>(threadIdx.x) - 64) <
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
} else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
static_cast<unsigned int>(block_T)) { static_cast<unsigned int>(block_T)) {
// casting to unsigned before the comparison tests for both negative and // casting to unsigned before the comparison tests for both negative and
// out-of-range values of (int)threadIdx.x - 64. // out-of-range values of (int)threadIdx.x - 64.
int s_in_block = block_S, int s_in_block = block_S, t_in_block = static_cast<int>(threadIdx.x) - 64,
t_in_block = (int)threadIdx.x - 64, s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
s = s_in_block + s_block_begin, p_buf[s_in_block][t_in_block] =
t = t_in_block + t_block_begin; (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
} }
__syncthreads(); __syncthreads();
...@@ -748,10 +623,11 @@ void mutual_information_backward_kernel( ...@@ -748,10 +623,11 @@ void mutual_information_backward_kernel(
static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) { static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// The following statement is really operating on the gradients; // The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin // it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.: // on the indexes, to equation (3a) above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] + // p_grad[b,s,t] =
// p_grad[b][s][t + 1] * yderiv[b][s][t] // p_grad[b,s+1,t-t_offset] * term1(b,s,t) + (3a)
p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] + // p_grad[b,s,t+1] * term2(b,s,t)
p_buf[s][t] = (p_buf[s + 1][t + neg_t_offset] * px_buf[s][t] +
p_buf[s][t + 1] * py_buf[s][t]); p_buf[s][t + 1] * py_buf[s][t]);
} }
} }
...@@ -761,24 +637,27 @@ void mutual_information_backward_kernel( ...@@ -761,24 +637,27 @@ void mutual_information_backward_kernel(
// Write out p_grad, px_grad and py_grad. // Write out p_grad, px_grad and py_grad.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
int s_in_block = i / BLOCK_SIZE, int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
t_in_block = i % BLOCK_SIZE, s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
// s_end and t_end are the one-past-the-end of the (x,y) sequences, but // s_end and t_end are the one-past-the-end of the (x,y) sequences, but
// the one-past-the-end element of p_grad would be (s_end + 1, t_end + 1). // the one-past-the-end element of p_grad would be (s_end + 1, t_end + 1).
if (t <= t_end && s <= s_end) { if (t <= t_end && s <= s_end) {
p_grad[b][s][t] = p_buf[s_in_block][t_in_block]; p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
if (s < s_end) { // write px_grad, which is of shape [B][S][T + 1] if (s < s_end && t <= t_end - neg_t_offset) {
// From (eq. 7): // write px_grad, which is of shape [B][S][T + 1] if !modified,
// px_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] // [B][S][T] if modified. the condition "t <= t_end - neg_t_offset"
px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block] * // becomes "t <= t_end" if !modified, and "t <= t_end - 1" if
// modified, keeping us within the bounds of px_grad.
// From (eq. 3b):
// px_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t)
px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block + neg_t_offset] *
px_buf[s_in_block][t_in_block]); px_buf[s_in_block][t_in_block]);
} }
if (t < t_end) { // write py_grad, which is of shape [B][S + 1][T] if (t < t_end) { // write py_grad, which is of shape [B][S + 1][T]
// from (eq. 8): // from (eq. 3c):
// py_grad[b][s][t] = p_grad[b][s][t + 1] * yderiv[b][s][t] // py_grad[b,s,t] = p_grad[b,s,t+1] * term2(b,s,t)
py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] * py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] *
py_buf[s_in_block][t_in_block]); py_buf[s_in_block][t_in_block]);
} }
...@@ -791,81 +670,77 @@ void mutual_information_backward_kernel( ...@@ -791,81 +670,77 @@ void mutual_information_backward_kernel(
} }
} }
// forward of mutual_information. See """... """ comment of
// `mutual_information` in mutual_information.py for documentation of the
// forward of mutual_information. See """... """ comment of `mutual_information` in // behavior of this function.
// mutual_information.py for documentation of the behavior of this function. torch::Tensor MutualInformationCuda(torch::Tensor px, torch::Tensor py,
torch::Tensor mutual_information_cuda(torch::Tensor px, torch::optional<torch::Tensor> opt_boundary,
torch::Tensor py, torch::Tensor p) {
torch::Tensor boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() && p.device().is_cuda(), TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() &&
p.device().is_cuda(),
"inputs must be CUDA tensors"); "inputs must be CUDA tensors");
auto scalar_t = px.scalar_type(); auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), const int B = px.size(0), S = px.size(1), T = py.size(2);
S = px.size(1), TORCH_CHECK(px.size(2) == T || px.size(2) == T + 1);
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4)); auto boundary = opt_boundary.value_or(
TORCH_CHECK(boundary.device().is_cuda() && torch::tensor({0, 0, S, T},
boundary.dtype() == torch::kInt64); torch::dtype(torch::kInt64).device(px.device()))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cuda() && boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts); torch::Tensor ans = torch::empty({B}, opts);
// num_threads and num_blocks and BLOCK_SIZE can be tuned. // num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128). // (however, num_threads may not be less than 128).
const int num_threads = 128, const int num_threads = 128, num_blocks = 256, BLOCK_SIZE = 32;
num_blocks = 256,
BLOCK_SIZE = 32;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1), // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
// so dividing by BLOCK_SIZE rounding up we get e.g. // so dividing by BLOCK_SIZE rounding up we get e.g.
// (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1 // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
const int num_s_blocks = S / BLOCK_SIZE + 1, const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1, num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1; num_iters = num_s_blocks + num_t_blocks - 1;
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cuda_stub", ([&] { AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cuda_stub", ([&] {
for (int iter = 0; iter < num_iters; ++iter) { for (int iter = 0; iter < num_iters; ++iter) {
mutual_information_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>( mutual_information_kernel<scalar_t, BLOCK_SIZE>
px.packed_accessor32<scalar_t, 3>(), <<<num_blocks, num_threads>>>(
py.packed_accessor32<scalar_t, 3>(), px.packed_accessor32<scalar_t, 3>(),
p.packed_accessor32<scalar_t, 3>(), py.packed_accessor32<scalar_t, 3>(),
boundary.packed_accessor32<int64_t, 2>(), p.packed_accessor32<scalar_t, 3>(),
ans.packed_accessor32<scalar_t, 1>(), boundary.packed_accessor32<int64_t, 2>(),
iter); ans.packed_accessor32<scalar_t, 1>(), iter);
} }
})); }));
return ans; return ans;
} }
// backward of mutual_information; returns (grad_px, grad_py) // backward of mutual_information; returns (grad_px, grad_py)
// If overwrite_ans_grad == true, will overwrite ans_grad with a value which // If overwrite_ans_grad == true, will overwrite ans_grad with a value which
// should be identical to the original ans_grad if the computation worked // should be identical to the original ans_grad if the computation worked
// as it should. // as it should.
std::vector<torch::Tensor> std::vector<torch::Tensor>
mutual_information_backward_cuda(torch::Tensor px, MutualInformationBackwardCuda(torch::Tensor px, torch::Tensor py,
torch::Tensor py, torch::optional<torch::Tensor> opt_boundary,
torch::Tensor boundary, torch::Tensor p, torch::Tensor ans_grad,
torch::Tensor p, bool overwrite_ans_grad) {
torch::Tensor ans_grad,
bool overwrite_ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional."); TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() && TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() &&
p.device().is_cuda() && ans_grad.device().is_cuda() && p.device().is_cuda() && ans_grad.device().is_cuda() &&
"inputs must be CUDA tensors"); "inputs must be CUDA tensors");
...@@ -873,55 +748,59 @@ mutual_information_backward_cuda(torch::Tensor px, ...@@ -873,55 +748,59 @@ mutual_information_backward_cuda(torch::Tensor px,
auto scalar_t = px.scalar_type(); auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), const int B = px.size(0), S = px.size(1), T = py.size(2);
S = px.size(1),
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); TORCH_CHECK(px.size(2) == T ||
px.size(2) == T + 1); // modified case || not-modified case
const bool modified = (px.size(2) == T);
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4)); auto boundary = opt_boundary.value_or(
TORCH_CHECK(boundary.device().is_cuda() && torch::tensor({0, 0, S, T},
boundary.dtype() == torch::kInt64); torch::dtype(torch::kInt64).device(px.device()))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cuda() && boundary.dtype() == torch::kInt64);
TORCH_CHECK(ans_grad.size(0) == B); TORCH_CHECK(ans_grad.size(0) == B);
bool has_boundary = (boundary.size(0) != 0); bool has_boundary = opt_boundary.has_value();
int T1 = T + (modified ? 0 : 1);
torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts), torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) : px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts)
torch::empty({B, S, T + 1}, opts)), : torch::empty({B, S, T1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) : py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
torch::empty({B, S + 1, T}, opts)); : torch::empty({B, S + 1, T}, opts));
// num_threads and num_blocks and BLOCK_SIZE can be tuned. // num_threads and num_blocks and BLOCK_SIZE can be tuned.
// (however, num_threads may not be less than 128). // (however, num_threads may not be less than 128).
const int num_threads = 128, const int num_threads = 128, num_blocks = 256, BLOCK_SIZE = 32;
num_blocks = 256,
BLOCK_SIZE = 32;
// The blocks cover the 'p' matrix, which is of size (B, S+1, T+1), // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
// so dividing by BLOCK_SIZE rounding up we get e.g. // so dividing by BLOCK_SIZE rounding up we get e.g.
// (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1 // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
const int num_s_blocks = S / BLOCK_SIZE + 1, const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1, num_t_blocks = T / BLOCK_SIZE + 1,
num_iters = num_s_blocks + num_t_blocks - 1; num_iters = num_s_blocks + num_t_blocks - 1;
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_backward_stub", ([&] { AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_backward_stub", ([&] {
for (int iter = num_iters - 1; iter >= 0; --iter) { for (int iter = num_iters - 1; iter >= 0; --iter) {
mutual_information_backward_kernel<scalar_t, BLOCK_SIZE><<<num_blocks, num_threads>>>( mutual_information_backward_kernel<scalar_t, BLOCK_SIZE>
px.packed_accessor32<scalar_t, 3>(), <<<num_blocks, num_threads>>>(
py.packed_accessor32<scalar_t, 3>(), px.packed_accessor32<scalar_t, 3>(),
p.packed_accessor32<scalar_t, 3>(), py.packed_accessor32<scalar_t, 3>(),
ans_grad.packed_accessor32<scalar_t, 1>(), p.packed_accessor32<scalar_t, 3>(),
p_grad.packed_accessor32<scalar_t, 3>(), ans_grad.packed_accessor32<scalar_t, 1>(),
px_grad.packed_accessor32<scalar_t, 3>(), p_grad.packed_accessor32<scalar_t, 3>(),
py_grad.packed_accessor32<scalar_t, 3>(), px_grad.packed_accessor32<scalar_t, 3>(),
boundary.packed_accessor32<int64_t, 2>(), py_grad.packed_accessor32<scalar_t, 3>(),
iter, boundary.packed_accessor32<int64_t, 2>(), iter,
overwrite_ans_grad); overwrite_ans_grad);
} }
})); }));
// std::cout << "p_grad = " << p_grad;
return std::vector<torch::Tensor>({px_grad, py_grad}); return std::vector<torch::Tensor>({px_grad, py_grad});
} }
} // namespace fast_rnnt
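For reference, the forward recursion that the host functions above drive (documented in the """...""" comment of mutual_information.py further down) can be written in a few lines of PyTorch. The sketch below is not part of this pull request; the function name is made up, and it covers only the !modified case with the default boundary [0, 0, S, T].

import torch


def mutual_information_reference(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    """px: [B][S][T+1], py: [B][S+1][T]; returns p[:, S, T] of shape [B]."""
    B, S, T1 = px.shape
    T = T1 - 1
    neg_inf = float("-inf")
    p = torch.full((B, S + 1, T + 1), neg_inf, dtype=px.dtype, device=px.device)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            # p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
            #                    p[b,s,t-1] + py[b,s,t-1])
            term_x = (p[:, s - 1, t] + px[:, s - 1, t] if s > 0
                      else torch.full((B,), neg_inf, dtype=px.dtype, device=px.device))
            term_y = (p[:, s, t - 1] + py[:, s, t - 1] if t > 0
                      else torch.full((B,), neg_inf, dtype=px.dtype, device=px.device))
            p[:, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S, T]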
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/utils.h"
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src) {
TORCH_CHECK(src.dim() == 1, "Only one-dimensional tensors are supported");
TORCH_CHECK(src.scalar_type() == torch::kLong, "Only support LongTensor");
TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
int32_t dim = src.numel();
if (src.device().type() == torch::kCPU) {
int64_t min_value = std::numeric_limits<int64_t>::max();
int64_t *src_data = src.data_ptr<int64_t>();
for (int32_t i = dim - 1; i >= 0; --i) {
min_value = std::min(src_data[i], min_value);
src[i] = min_value;
}
} else {
#ifdef FT_WITH_CUDA
TORCH_CHECK(src.device().is_cuda());
internal::MinOp<int64_t> min_op;
auto src_data = src.data_ptr<int64_t>();
internal::ConstReversedPtr<int64_t> src_ptr =
internal::ConstReversedPtr<int64_t>(src_data, dim);
internal::ReversedPtr<int64_t> dest_ptr =
internal::ReversedPtr<int64_t>(src_data, dim);
// The first time is to determine temporary device storage requirements.
std::size_t temp_storage_bytes = 0;
auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
torch::dtype(torch::kInt8).device(src.device()));
int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif // FT_WITH_CUDA
}
}
} // namespace fast_rnnt
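MonotonicLowerBound replaces each element, in place, with the minimum of itself and everything to its right (a suffix minimum), which leaves the sequence non-decreasing; the cub::DeviceScan::InclusiveScan call above does the same scan from the back on GPU. A minimal PyTorch sketch of the same operation, illustrative only (the function name is made up):

import torch


def monotonic_lower_bound_reference(src: torch.Tensor) -> torch.Tensor:
    # Out-of-place equivalent of the in-place MonotonicLowerBound above.
    assert src.dim() == 1 and src.dtype == torch.int64
    out = src.clone()
    for i in range(out.numel() - 2, -1, -1):
        out[i] = torch.minimum(out[i], out[i + 1])
    return out


# e.g. tensor([3, 1, 4, 1, 5]) -> tensor([1, 1, 1, 1, 5])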
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_
#include "torch/extension.h"
#ifdef FT_WITH_CUDA
#include "cub/cub.cuh" // NOLINT
namespace fast_rnnt {
namespace internal {
template <typename T> struct MinOp {
__host__ __device__ __forceinline__ T operator()(const T &a,
const T &b) const {
return (a > b) ? b : a;
}
};
// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
const T *data;
// data points to the last element now
explicit ConstReversedPtr(const T *data, int32_t size)
: data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
return data[-i];
}
__host__ __device__ __forceinline__ ConstReversedPtr
operator+(int32_t n) const {
ConstReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ const T &operator*() const {
return *data;
}
};
template <typename T> struct ReversedPtr {
T *data;
// data points to the last element now
explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ T &operator[](int32_t i) {
return data[-i];
}
__host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
ReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ T &operator*() { return *data; }
};
} // namespace internal
} // namespace fast_rnnt
namespace std {
// value_type is required by cub::DeviceScan::InclusiveSum
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
typedef T value_type;
};
} // namespace std
#endif // FT_WITH_CUDA
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_UTILS_H_
add_subdirectory(csrc)
add_subdirectory(tests)
include_directories(${CMAKE_SOURCE_DIR})
include(transform)
# please keep the list sorted
set(fast_rnnt_srcs
fast_rnnt.cu
mutual_information.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
transform(OUTPUT_VARIABLE fast_rnnt_srcs SRCS ${fast_rnnt_srcs})
endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
if(UNIX AND NOT APPLE)
target_link_libraries(_fast_rnnt
PRIVATE
${PYTHON_LIBRARY}
${TORCH_DIR}/lib/libtorch_python.so
)
endif()
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/python/csrc/fast_rnnt.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/utils.h"
PYBIND11_MODULE(_fast_rnnt, m) {
m.doc() = "Python wrapper for Fast Rnnt.";
fast_rnnt::PybindMutualInformation(m);
fast_rnnt::PybindUtils(m);
}
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
#define FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
#include "pybind11/pybind11.h"
namespace py = pybind11;
#endif // FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
namespace fast_rnnt {
void PybindMutualInformation(py::module &m) {
m.def(
"mutual_information_forward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p) -> torch::Tensor {
fast_rnnt::DeviceGuard guard(px.device());
if (px.device().is_cpu()) {
return MutualInformationCpu(px, py, boundary, p);
} else {
#ifdef FT_WITH_CUDA
return MutualInformationCuda(px, py, boundary, p);
#else
TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
"that you compiled the code with K2_WITH_CUDA.");
return torch::Tensor();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"));
m.def(
"mutual_information_backward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary, torch::Tensor p,
torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
fast_rnnt::DeviceGuard guard(px.device());
if (px.device().is_cpu()) {
return MutualInformationBackwardCpu(px, py, boundary, p, ans_grad);
} else {
#ifdef FT_WITH_CUDA
return MutualInformationBackwardCuda(px, py, boundary, p, ans_grad,
true);
#else
TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
"that you compiled the code with K2_WITH_CUDA.");
return std::vector<torch::Tensor>();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad"));
}
} // namespace fast_rnnt
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindMutualInformation(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"
namespace fast_rnnt {
void PybindUtils(py::module &m) {
m.def(
"monotonic_lower_bound_",
[](torch::Tensor &src) -> void {
DeviceGuard guard(src.device());
if (src.dim() == 1) {
MonotonicLowerBound(src);
} else if (src.dim() == 2) {
int32_t dim0 = src.sizes()[0];
for (int32_t i = 0; i < dim0; ++i) {
auto sub = src.index({i});
MonotonicLowerBound(sub);
}
} else {
TORCH_CHECK(false,
"Only support 1 dimension and 2 dimensions tensor");
}
},
py::arg("src"));
m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
return true;
#else
return false;
#endif
});
}
} // namespace fast_rnnt
/**
* @brief python wrappers for utils.h
*
* @copyright
* Copyright 2022 Xiaomi Corp. (author: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindUtils(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
from _fast_rnnt import monotonic_lower_bound_
from _fast_rnnt import with_cuda
from .mutual_information import mutual_information_recursion
from .mutual_information import joint_mutual_information_recursion
from .rnnt_loss import do_rnnt_pruning
from .rnnt_loss import get_rnnt_logprobs
from .rnnt_loss import get_rnnt_logprobs_joint
from .rnnt_loss import get_rnnt_logprobs_pruned
from .rnnt_loss import get_rnnt_logprobs_smoothed
from .rnnt_loss import get_rnnt_prune_ranges
from .rnnt_loss import rnnt_loss
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import _fast_rnnt
from torch import Tensor
from typing import Tuple, Optional, Sequence, Union, List
class MutualInformationRecursionFunction(torch.autograd.Function):
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired.
"""
@staticmethod
def forward(
ctx,
px: torch.Tensor,
py: torch.Tensor,
pxy_grads: List[Optional[torch.Tensor]],
boundary: Optional[torch.Tensor] = None,
return_grad: bool = False,
) -> torch.Tensor:
"""
Computing mutual information between two sequences of real vectors.
Args:
px:
A torch.Tensor of some floating point type, with shape
``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the
length of the ``x`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
length of the ``y`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols). In the mutual
information application, ``px[b][s][t]`` would represent the
following log odds ratio; ignoring the b index on the right
to make the notation more
compact::
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in
practice be of the form::
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where ``N`` is the number of terms that the sum over ``t'``
included, which might include some or all of the other sequences as
well as this one.
Note:
we don't require ``px`` and ``py`` to be contiguous, but the
code assumes for optimization purposes that the ``T`` axis has
stride 1.
py:
A torch.Tensor of the same dtype as ``px``, with shape
``[B][S+1][T]``, representing::
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py``
are contiguous.
pxy_grads:
A List to store the return grads of ``px`` and ``py``
if return_grad == True.
It remains unchanged if return_grad == False.
See `this PR <https://github.com/k2-fsa/k2/pull/924>` for more
information about why we add this parameter.
Note:
the length of the list must be 2, where the first element
represents the grads of ``px`` and the second one represents
the grads of ``py``.
boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``,
with ``0 <= s_begin <= s_end <= S`` and ``0 <= t_begin <= t_end <= T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the
``x`` and ``y`` sequences respectively, and can be used if not
all sequences are
of the same length.
return_grad:
Whether to return grads of ``px`` and ``py``; these grads, which stand
for the occupation probabilities, are the output of the backward pass
with a ``fake gradient``. The ``fake gradient`` is the same as the
gradient you'd get if you did
``torch.autograd.grad((scores.sum()), [px, py])``.
This is useful to implement the pruned version of rnnt loss.
Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of
the mutual information between the b'th pair of sequences. This is
defined by the following recursion on ``p[b,s,t]`` (where ``p``
is of shape ``[B,S+1,T+1]``), representing a mutual information
between sub-sequences of lengths ``s`` and ``t``::
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
where we handle edge cases by treating quantities with negative
indexes as **-infinity**. The extension to cases where the
boundaries are specified should be obvious; it just works on
shorter sequences with offsets into ``px`` and ``py``.
"""
(B, S, T1) = px.shape
T = py.shape[-1]
assert T1 in [T, T + 1]
assert py.shape == (B, S + 1, T)
if boundary is not None:
assert boundary.shape == (B, 4)
# p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is the
# mutual information of the pair of subsequences of x and y that
# are of length s and t respectively. p[0][0] will be 0.0 and p[S][T]
# is the mutual information of the entire pair of sequences,
# i.e. of lengths S and T respectively.
# It is computed as follows (in C++ and CUDA):
# p[b,0,0] = 0.0
# p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
# p[b,s,t-1] + py[b,s,t-1])
# if s > 0 or t > 0,
# treating values with any -1 index as -infinity.
# .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _fast_rnnt.mutual_information_forward(px, py, boundary, p)
px_grad, py_grad = None, None
if return_grad or px.requires_grad or py.requires_grad:
ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
(px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
px, py, boundary, p, ans_grad)
ctx.save_for_backward(px_grad, py_grad)
assert len(pxy_grads) == 2
pxy_grads[0] = px_grad
pxy_grads[1] = py_grad
return ans
@staticmethod
def backward(
ctx, ans_grad: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
(px_grad, py_grad) = ctx.saved_tensors
(B,) = ans_grad.shape
ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1)
px_grad *= ans_grad
py_grad *= ans_grad
return (px_grad, py_grad, None, None, None)
def mutual_information_recursion(
px: Tensor,
py: Tensor,
boundary: Optional[Tensor] = None,
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired. The definitions of the arguments are definitions that
would be used when computing this type of mutual information, but you can
also view them as arbitrary quantities and just make use of the formula
computed by this function.
Args:
px:
A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``,
where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence
(including representations of ``EOS`` symbols but not ``BOS`` symbols),
and ``T`` is the length of the ``y`` sequence (including representations
of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information
application, ``px[b][s][t]`` would represent the following log odds
ratio; ignoring the b index on the right to make the notation more
compact::
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice
be of the form::
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where ``N`` is the number of terms that the sum over ``t'`` included,
which might include some or all of the other sequences as well as this
one.
Note:
we don't require ``px`` and ``py`` to be contiguous, but the
code assumes for optimization purposes that the ``T`` axis has
stride 1.
py:
A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``,
representing::
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are
contiguous.
boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``,
with ``0 <= s_begin <= s_end <= S`` and ``0 <= t_begin <= t_end <= T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the ``x`` and
``y`` sequences respectively, and can be used if not all sequences are
of the same length.
return_grad:
Whether to return grads of ``px`` and ``py``; these grads, which stand for
the occupation probabilities, are the output of the backward pass with a
``fake gradient``. The ``fake gradient`` is the same as the gradient
you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``.
This is useful to implement the pruned version of rnnt loss.
Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual
information between the b'th pair of sequences. This is defined by
the following recursion on ``p[b,s,t]`` (where ``p`` is of shape
``[B,S+1,T+1]``), representing a mutual information between sub-sequences
of lengths ``s`` and ``t``::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
where we handle edge cases by treating quantities with negative indexes
as **-infinity**. The extension to cases where the boundaries are
specified should be obvious; it just works on shorter sequences with
offsets into ``px`` and ``py``.
"""
assert px.ndim == 3
B, S, T1 = px.shape
T = py.shape[-1]
assert px.shape[-1] in [T, T + 1] # if T, then "modified".
assert py.shape == (B, S + 1, T)
assert px.dtype == py.dtype
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following assertions are for efficiency
assert px.is_contiguous()
assert py.is_contiguous()
pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
px_grad, py_grad = pxy_grads
return (scores, (px_grad, py_grad)) if return_grad else scores
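# A minimal usage sketch of mutual_information_recursion (illustrative only,
# not part of the original file; the helper name below is made up): random
# px / py of the shapes documented above, with a default-style boundary.
def _example_mutual_information_usage():
    B, S, T = 2, 5, 7
    px = torch.randn(B, S, T + 1)  # the !modified case
    py = torch.randn(B, S + 1, T)
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)
    scores = mutual_information_recursion(px, py, boundary)  # shape (B,)
    # With return_grad=True we also get the occupation-probability "grads".
    scores2, (px_grad, py_grad) = mutual_information_recursion(
        px, py, boundary, return_grad=True
    )
    return scores, scores2, px_grad, py_grad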
def _inner_product(a: Tensor, b: Tensor) -> Tensor:
"""
Does inner product on the last dimension, with expected broadcasting,
i.e. equivalent to (a * b).sum(dim=-1)
without creating a large temporary.
"""
assert a.shape[-1] == b.shape[-1] # The last dim must be equal
a = a.unsqueeze(-2) # (..., 1, K)
b = b.unsqueeze(-1) # (..., K, 1)
c = torch.matmul(a, b) # (..., 1, 1)
return c.squeeze(-1).squeeze(-1)
def joint_mutual_information_recursion(
px: Sequence[Tensor],
py: Sequence[Tensor],
boundary: Optional[Tensor] = None,
) -> Sequence[Tensor]:
"""A recursion that is useful for modifications of RNN-T and similar loss
functions, where the recursion probabilities have a number of terms and you
want them reported separately. See mutual_information_recursion() for more
documentation of the basic aspects of this.
Args:
px:
a sequence of Tensors, each of the same shape [B][S][T+1]
py:
a sequence of Tensor, each of the same shape [B][S+1][T],
the sequence must be the same length as px.
boundary:
optionally, a LongTensor of shape [B][4] containing rows
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S
and 0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
These are the beginning and one-past-the-last positions in the x
and y sequences respectively, and can be used if not all
sequences are of the same length.
Returns:
a Tensor of shape (len(px), B),
whose sum over dim 0 is the total log-prob of the recursion mentioned
below, per sequence. The first element of the sequence of length len(px)
is "special", in that it has an offset term reflecting the difference
between sum-of-log and log-of-sum; for more interpretable loss values,
the "main" part of your loss function should be first.
The recursion below applies if boundary == None, when it defaults
to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px
and py::
p = tensor of shape (B, S+1, T+1), containing -infinity
p[b,0,0] = 0.0
# do the following in loop over s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
p[b,s,t-1] + py_sum[b,s,t-1])
(if s > 0 or t > 0)
return p[:][S][T]
This function lets you implement the above recursion efficiently, except
that it gives you a breakdown of the contribution from all the elements of
px and py separately. As noted above, the first element of the
sequence is "special".
"""
N = len(px)
assert len(py) == N and N > 0
B, S, T1 = px[0].shape
T = py[0].shape[2]
assert T1 in [T, T + 1] # T if modified...
assert py[0].shape == (B, S + 1, T)
assert px[0].dtype == py[0].dtype
px_cat = torch.stack(
px, dim=0
) # (N, B, S, T+1) if !modified, (N, B, S, T) if modified.
py_cat = torch.stack(py, dim=0) # (N, B, S+1, T)
px_tot = px_cat.sum(dim=0) # (B, S, T+1)
py_tot = py_cat.sum(dim=0) # (B, S+1, T)
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
# The following assertions are for efficiency
assert px_tot.ndim == 3
assert py_tot.ndim == 3
p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)
# note, tot_probs is without grad.
tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p)
# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
(px_grad,
py_grad) = _fast_rnnt.mutual_information_backward(px_tot, py_tot, boundary, p,
ans_grad)
px_grad = px_grad.reshape(1, B, -1)
py_grad = py_grad.reshape(1, B, -1)
px_cat = px_cat.reshape(N, B, -1)
py_cat = py_cat.reshape(N, B, -1)
# get rid of -inf, would generate nan on product with 0
px_cat = px_cat.clamp(min=torch.finfo(px_cat.dtype).min)
py_cat = py_cat.clamp(min=torch.finfo(py_cat.dtype).min)
x_prods = _inner_product(px_grad, px_cat) # (N, B)
y_prods = _inner_product(py_grad, py_cat) # (N, B)
# If all the occupation counts were exactly 1.0 (i.e. no partial counts),
# "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
# will be more positive due to the difference between log-of-sum and
# sum-of-log
prods = x_prods + y_prods # (N, B)
with torch.no_grad():
offset = tot_probs - prods.sum(dim=0) # (B,)
prods[0] += offset
return prods # (N, B)
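# A minimal usage sketch of joint_mutual_information_recursion (illustrative
# only, not part of the original file; the helper name below is made up):
# two (px, py) term pairs whose contributions are reported separately, and
# whose sum over dim 0 is the total log-prob of the recursion on the summed
# px / py.
def _example_joint_mutual_information():
    B, S, T = 2, 3, 4
    px = [torch.randn(B, S, T + 1) for _ in range(2)]
    py = [torch.randn(B, S + 1, T) for _ in range(2)]
    prods = joint_mutual_information_recursion(px, py)  # shape (2, B)
    return prods.sum(dim=0)  # per-sequence total log-prob, shape (B,)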
# Copyright 2021 Xiaomi Corp. (author: Daniel Povey, Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import fast_rnnt
import torch
from torch import Tensor
from typing import Optional, Tuple, Union
from .mutual_information import mutual_information_recursion
def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
"""
Insert -inf's into `px` in appropriate places if `boundary` is not
None. If boundary == None and modified == False, px[:,:,-1] will
be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
to be -infinity.
Args:
px: a Tensor of shape [B][S][T+1] (this function is only
called if modified == False, see other docs for `modified`)
px is modified in-place and returned.
boundary: None, or a Tensor of shape [B][4] containing
[s_begin, t_begin, s_end, t_end]; we need only t_end.
"""
if boundary is None:
return px
B, S, T1 = px.shape
boundary = boundary[:, 3].reshape(B, 1, 1).expand(B, S, T1)
return px.scatter_(dim=2, index=boundary, value=float("-inf"))
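# A tiny illustration of fix_for_boundary (not part of the original file; the
# helper name below is made up): with t_end == 2 for the only sequence, column
# 2 of px becomes -inf, so no symbol can be emitted past that sequence's last
# frame.
def _example_fix_for_boundary():
    B, S, T1 = 1, 2, 4
    px = torch.zeros(B, S, T1)
    boundary = torch.tensor([[0, 0, S, 2]], dtype=torch.int64)
    px = fix_for_boundary(px, boundary)
    # px[0, :, 2] is now -inf; all other entries are unchanged.
    return px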
def get_rnnt_logprobs(
lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""
Reduces RNN-T problem (the simple case, where joiner network is just
addition), to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This function is called from rnnt_loss_simple(), but may be useful for
other purposes.
Args:
lm:
Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape::
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is the size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am:
Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape::
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is the size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
of length s and t respectively. px[b][s][t] represents the
probability of extending the subsequences of length (s,t) by one in
the s direction, given the particular symbol, and py[b][s][t]
represents the probability of extending the subsequences of length
(s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert lm.ndim == 3
assert am.ndim == 3
assert lm.shape[0] == am.shape[0]
assert lm.shape[2] == am.shape[2]
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
# subtracting am_max and lm_max is to ensure the probs are in a good range
# to do exp() without causing underflow or overflow.
am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1]
lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1]
am_probs = (am - am_max).exp()
lm_probs = (lm - lm_max).exp()
# normalizers: [B][S+1][T]
normalizers = (
torch.matmul(lm_probs, am_probs.transpose(1, 2))
+ torch.finfo(am_probs.dtype).tiny
).log()
# add lm_max and am_max to normalizers, to make it as if we had not
# subtracted am_max and lm_max above.
normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T]
# px is the probs of the actual symbols..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
if not modified:
px_am = torch.cat(
(
px_am,
torch.full(
(B, S, 1),
float("-inf"),
device=px_am.device,
dtype=px_am.dtype,
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
px_lm = torch.gather(
lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]
px = px_am + px_lm # [B][S][T+1], last slice with indexes out of
# boundary is -inf
px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1]
# py is the probs of termination symbols, of shape [B][S+1][T]
py_am = am[:, :, termination_symbol].unsqueeze(1) # [B][1][T]
py_lm = lm[:, :, termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
if not modified:
px = fix_for_boundary(px, boundary)
return (px, py)
def rnnt_loss_simple(
lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
"""A simple case of the RNN-T loss, where the 'joiner' network is just
addition.
Args:
lm:
language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes
am:
acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols:
the symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
return_grad:
Whether to return grads of px and py; these grads, which stand for the
occupation probabilities, are the output of the backward pass with a
`fake gradient`. The `fake gradient` is the same as the gradient you'd
get if you did `torch.autograd.grad((-loss.sum()), [px, py])`; note that
the loss here is the loss with reduction "none".
This is useful to implement the pruned version of rnnt loss.
Returns:
If return_grad is False, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch if reduction equals
to "none", otherwise a scalar with the reduction applied.
If return_grad is True, the grads of px and py, which is the output of
backward with a `fake gradient`(see above), will be returned too. And the
returned value will be a tuple like (loss, (px_grad, py_grad)).
"""
px, py = get_rnnt_logprobs(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
if reduction == "none":
loss = -negated_loss
elif reduction == "mean":
loss = -torch.mean(negated_loss)
elif reduction == "sum":
loss = -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
return (loss, scores_and_grads[1]) if return_grad else loss
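# A minimal usage sketch of rnnt_loss_simple (illustrative only, not part of
# the original file; the helper name below is made up): random lm / am /
# symbols of the documented shapes.
def _example_rnnt_loss_simple():
    B, S, T, C = 2, 5, 20, 10
    lm = torch.randn(B, S + 1, C)
    am = torch.randn(B, T, C)
    symbols = torch.randint(low=1, high=C, size=(B, S))
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)
    loss = rnnt_loss_simple(lm, am, symbols, termination_symbol=0,
                            boundary=boundary, reduction="sum")
    return loss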
def get_rnnt_logprobs_joint(
logits: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Reduces RNN-T problem to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This function is called from rnnt_loss().
Args:
logits:
The output of joiner network, with shape (B, T, S + 1, C),
i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary)::
px: logprobs, of shape [B][S][T+1]
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert logits.ndim == 4
(B, T, S1, C) = logits.shape
S = S1 - 1
assert symbols.shape == (B, S)
normalizers = torch.logsumexp(logits, dim=3)
normalizers = normalizers.permute((0, 2, 1))
px = torch.gather(
logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
).squeeze(-1)
px = px.permute((0, 2, 1))
if not modified:
px = torch.cat(
(
px,
torch.full(
(B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
px[:, :, :T] -= normalizers[:, :S, :]
py = (
logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
) # [B][S+1][T]
py -= normalizers
px = px.contiguous()
py = py.contiguous()
if not modified:
px = fix_for_boundary(px, boundary)
return (px, py)
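# Shape sketch (illustrative only; the sizes below are arbitrary assumptions):
#
#   B, T, S, C = 2, 10, 3, 5
#   logits = torch.randn(B, T, S + 1, C)
#   symbols = torch.randint(1, C, (B, S))
#   px, py = get_rnnt_logprobs_joint(logits, symbols, termination_symbol=0)
#   # px.shape == (B, S, T + 1), py.shape == (B, S + 1, T), the inputs
#   # expected by mutual_information_recursion().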
def rnnt_loss(
logits: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Optional[Tensor] = None,
modified: bool = False,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A normal RNN-T loss, which uses a 'joiner' network output as input,
i.e. a 4 dimensions tensor.
Args:
logits:
The output of joiner network, with shape (B, T, S + 1, C),
i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
symbols:
The symbol sequences, a LongTensor of shape [B][S], and elements
in {0..C-1}.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
Returns:
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch, otherwise a scalar
with the reduction applied.
"""
px, py = get_rnnt_logprobs_joint(
logits=logits,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
if reduction == "none":
return -negated_loss
elif reduction == "mean":
return -torch.mean(negated_loss)
elif reduction == "sum":
return -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
def _adjust_pruning_lower_bound(
s_begin: torch.Tensor, s_range: int
) -> torch.Tensor:
"""Adjust s_begin (pruning lower bound) to make it satisfied the following
constrains
- monotonic increasing, i.e. s_begin[i] <= s_begin[i + 1]
- start with symbol 0 at first frame.
- s_begin[i + 1] - s_begin[i] < s_range, whicn means that we can't skip
any symbols.
To make it monotonic increasing, we can use `monotonic_lower_bound` function
in k2, which guarantee `s_begin[i] <= s_begin[i + 1]`. The main idea is:
traverse the array in reverse order and update the elements by
`min_value = min(a_begin[i], min_value)`, the initial `min_value` set to
`inf`.
The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range`
constrain is a little tricky. We first transform `s_begin` with
`s_begin = -(s_begin - (s_range - 1) * torch.arange(0,T))`
then we make the transformed `s_begin` monotonic increasing, after that,
we transform back `s_begin` with the same formula as the previous
transformation. The idea is: if we want to make
`s_begin[i + 1] - s_begin[i] < s_range` we only need to make
`-(s_begin[i] - i * (s_range - 1))` a non-decreasing array. Proof:
-(s_begin[i] - i * (s_range - 1)) <= -(s_begin[i + 1] - (i + 1) * (s_range - 1))
-s_begin[i] <= -s_begin[i + 1] + (i + 1) * (s_range - 1) - i * (s_range - 1)
-s_begin[i] <= -s_begin[i + 1] + s_range - 1
s_begin[i + 1] - s_begin[i] <= s_range - 1
s_begin[i + 1] - s_begin[i] < s_range
The above transformation can not guarantee the start symbol to be 0, so we
have to make all the elements that less than 0 to be 0 before transforming
back the `s_begin`.
"""
# s_begin (B, T)
(B, T) = s_begin.shape
fast_rnnt.monotonic_lower_bound_(s_begin)
# do the magic transformation
s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
)
# make the transformed tensor to be non-decreasing
fast_rnnt.monotonic_lower_bound_(s_begin)
# make start symbol to be zero.
s_begin = torch.where(s_begin < 0, 0, s_begin)
# do the magic transformation again to recover s_begin
s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
)
return s_begin
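# Worked example (illustrative only; values chosen arbitrarily, and assuming
# monotonic_lower_bound_ computes the reverse running minimum described in the
# docstring above): for a single row s_begin = [0, 2, 1, 4] with s_range = 2:
#   after the first monotonic_lower_bound_:          [0, 1, 1, 4]
#   after s_begin = -(s_begin - 1 * [0, 1, 2, 3]):   [0, 0, 1, -1]
#   after the second monotonic_lower_bound_:         [-1, -1, -1, -1]
#   after clamping negatives to zero:                [0, 0, 0, 0]
#   after transforming back:                         [0, 1, 2, 3]
# which is monotonically increasing, starts at 0, and has adjacent differences
# strictly less than s_range.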
def get_rnnt_prune_ranges(
px_grad: torch.Tensor,
py_grad: torch.Tensor,
boundary: torch.Tensor,
s_range: int,
) -> torch.Tensor:
"""Get the pruning ranges of normal rnnt loss according to the grads
of px and py returned by mutual_information_recursion.
For each sequence with T frames, we will generate a tensor with the shape of
(T, s_range) containing the information that which symbols will be token
into consideration for each frame. For example, here is a sequence with 10
frames and the corresponding symbols are `[A B C D E F]`, if the s_range
equals 3, one possible ranges tensor will be::
[[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [1, 2, 3],
[1, 2, 3], [1, 2, 3], [3, 4, 5], [3, 4, 5], [3, 4, 5]]
which means we only consider `[A B C]` at frame 0, 1, 2, 3, and `[B C D]`
at frame 4, 5, 6, `[D E F]` at frame 7, 8, 9.
We can only consider limited number of symbols because frames and symbols
are monotonic aligned, theoretically it can only generate particular range
of symbols given a particular frame.
Note:
For the generated tensor ranges, ranges[:, 0] is a monotonic increasing
tensor from 0 to `len(symbols)` and it satisfies
`ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
symbols.
Args:
px_grad:
The gradient of px, see docs in `mutual_information_recursion` for more
details of px.
py_grad:
The gradient of py, see docs in `mutual_information_recursion` for more
details of py.
boundary:
a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]
s_range:
How many symbols to keep for each frame.
Returns:
A tensor containing the indexes of the kept symbols for each frame, with
shape (B, T, s_range).
"""
(B, S, T1) = px_grad.shape
T = py_grad.shape[-1]
assert T1 in [T, T + 1]
assert py_grad.shape == (B, S + 1, T)
assert boundary.shape == (B, 4)
assert s_range >= 1
if s_range > S:
s_range = S
px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device)
py_pad = torch.zeros(
(B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device
)
py_grad_padded = py_grad if T1 == T else torch.cat((py_grad, py_pad), dim=2)
tot_grad = (
torch.cat((px_grad, px_pad), dim=1) + py_grad_padded
) # (B, S + 1, T1)
tot_grad = torch.cat(
(
torch.zeros(
(B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device
),
tot_grad,
),
dim=1,
)
tot_grad = torch.cumsum(tot_grad, dim=1)
diff_grad = tot_grad[:, s_range:, :] - tot_grad[:, 0:-s_range, :]
s_begin = torch.argmax(diff_grad, dim=1)
s_begin = s_begin[:, :T]
# Handle the values of s_begin at padding positions.
# The "- 1" below means we also fill the position of the last frame of real
# data with the padding value, which is `len(symbols) - s_range + 1`.
# This guarantees that we reach the last symbol at the last frame of real
# data.
mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
mask = mask < boundary[:, 3].reshape(B, 1) - 1
s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
# handle the cases when `len(symbols) < s_range`
s_begin_padding = torch.where(s_begin_padding >= 0, s_begin_padding, 0)
s_begin = torch.where(mask, s_begin, s_begin_padding)
# Adjust the lower bound so that it satisfies some constraints; see the docs
# in `_adjust_pruning_lower_bound` for more details of these constraints.
# T1 == T here means we are using the modified version of the transducer,
# where the third constraint becomes `s_begin[i + 1] - s_begin[i] < 2`,
# because only one symbol can be emitted per frame.
s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)
ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
s_range, device=px_grad.device
)
return ranges
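# Usage sketch (illustrative only; the call below and s_range=5 are assumptions
# about typical usage, not part of the original code):
#
#   _, (px_grad, py_grad) = rnnt_loss_simple(
#       lm, am, symbols, termination_symbol, boundary, return_grad=True
#   )
#   ranges = get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range=5)
#   # ranges: (B, T, 5), later consumed by do_rnnt_pruning()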
def do_rnnt_pruning(
am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Prune the output of encoder(am) output and prediction network(lm)
output of RNNT.
Args:
am:
The encoder output, with shape (B, T, C)
lm:
The prediction network output, with shape (B, S + 1, C)
ranges:
A tensor containing the symbol indexes for each frame that we want to
keep. Its shape is (B, T, s_range); see the docs in
`get_rnnt_prune_ranges` for more details of this tensor.
Returns:
Return the pruned am and lm with shape (B, T, s_range, C)
"""
# am (B, T, C)
# lm (B, S + 1, C)
# ranges (B, T, s_range)
assert ranges.shape[0] == am.shape[0]
assert ranges.shape[0] == lm.shape[0]
assert am.shape[1] == ranges.shape[1]
(B, T, s_range) = ranges.shape
(B, S1, C) = lm.shape
S = S1 - 1
# (B, T, s_range, C)
am_pruning = am.unsqueeze(2).expand((B, T, s_range, C))
# (B, T, s_range, C)
lm_pruning = torch.gather(
lm.unsqueeze(1).expand((B, T, S + 1, C)),
dim=2,
index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, C)),
)
return am_pruning, lm_pruning
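# Usage sketch (illustrative only): `encoder_out`, `decoder_out` and `joiner`
# are hypothetical user-side names, not defined in this file. The pruned
# outputs are typically combined by a joiner network into the
# (B, T, s_range, C) logits expected by rnnt_loss_pruned():
#
#   am_pruned, lm_pruned = do_rnnt_pruning(am=encoder_out, lm=decoder_out, ranges=ranges)
#   logits = joiner(am_pruned, lm_pruned)  # hypothetical joiner -> (B, T, s_range, C)
#   loss = rnnt_loss_pruned(logits, symbols, ranges, termination_symbol, boundary)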
def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
"""Roll tensor with different shifts for each row.
Note:
We assume the src is a 3 dimensions tensor and roll the last dimension.
Example:
>>> src = torch.arange(15).reshape((1,3,5))
>>> src
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]]])
>>> shift = torch.tensor([[1, 2, 3]])
>>> shift
tensor([[1, 2, 3]])
>>> _roll_by_shifts(src, shift)
tensor([[[ 4, 0, 1, 2, 3],
[ 8, 9, 5, 6, 7],
[12, 13, 14, 10, 11]]])
"""
assert src.dim() == 3
(B, T, S) = src.shape
assert shifts.shape == (B, T)
index = (
torch.arange(S, device=src.device)
.view((1, S))
.repeat((T, 1))
.repeat((B, 1, 1))
)
index = (index - shifts.reshape(B, T, 1)) % S
return torch.gather(src, 2, index)
def get_rnnt_logprobs_pruned(
logits: Tensor,
symbols: Tensor,
ranges: Tensor,
termination_symbol: int,
boundary: Tensor,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Construct px, py for mutual_information_recursion with pruned output.
Args:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C)
symbols:
The symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
ranges:
A tensor containing the symbol indexes for each frame that we want to keep.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
Returns px of shape (B, S, T) if modified, else (B, S, T + 1), and
py of shape (B, S + 1, T), as needed by mutual_information_recursion.
"""
# logits (B, T, s_range, C)
# symbols (B, S)
# ranges (B, T, s_range)
assert logits.ndim == 4
(B, T, s_range, C) = logits.shape
assert ranges.shape == (B, T, s_range)
(B, S) = symbols.shape
normalizers = torch.logsumexp(logits, dim=3)
symbols_with_terminal = torch.cat(
(
symbols,
torch.tensor(
[termination_symbol] * B,
dtype=torch.int64,
device=symbols.device,
).reshape((B, 1)),
),
dim=1,
)
# (B, T, s_range)
pruned_symbols = torch.gather(
symbols_with_terminal.unsqueeze(1).expand((B, T, S + 1)),
dim=2,
index=ranges,
)
# (B, T, s_range)
px = torch.gather(
logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1)
).squeeze(-1)
px = px - normalizers
# (B, T, S + 1), with indexes larger than s_range in dim 2 filled with -inf
px = torch.cat(
(
px,
torch.full(
(B, T, S + 1 - s_range),
float("-inf"),
device=px.device,
dtype=px.dtype,
),
),
dim=2,
)
# (B, T, S), with indexes outside the pruning range in dim 2 filled with -inf
px = _roll_by_shifts(px, ranges[:, :, 0])[:, :, :S]
px = px.permute((0, 2, 1))
if not modified:
px = torch.cat(
(
px,
torch.full(
(B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
py = logits[:, :, :, termination_symbol].clone() # (B, T, s_range)
py = py - normalizers
# (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf
py = torch.cat(
(
py,
torch.full(
(B, T, S + 1 - s_range),
float("-inf"),
device=py.device,
dtype=py.dtype,
),
),
dim=2,
)
# (B, T, S + 1), with indexes outside the pruning range in dim 2 filled with -inf
py = _roll_by_shifts(py, ranges[:, :, 0])
# (B, S + 1, T)
py = py.permute((0, 2, 1))
px = px.contiguous()
py = py.contiguous()
if not modified:
px = fix_for_boundary(px, boundary)
return (px, py)
def rnnt_loss_pruned(
logits: Tensor,
symbols: Tensor,
ranges: Tensor,
termination_symbol: int,
boundary: Tensor = None,
modified: bool = False,
reduction: Optional[str] = "mean",
) -> Tensor:
"""A RNN-T loss with pruning, which uses a pruned 'joiner' network output
as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
s_range means the symbols number kept for each frame.
Args:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C),
i.e. batch, time_seq_len, prune_range, num_classes
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
ranges:
A tensor containing the symbol indexes for each frame that we want to keep.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
boundary:
a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
Returns:
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each element of the batch, otherwise a scalar
with the reduction applied.
"""
px, py = get_rnnt_logprobs_pruned(
logits=logits,
symbols=symbols,
ranges=ranges,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
if reduction == "none":
return -negated_loss
elif reduction == "mean":
return -torch.mean(negated_loss)
elif reduction == "sum":
return -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
def get_rnnt_logprobs_smoothed(
lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
modified: bool = False,
) -> Tuple[Tensor, Tensor]:
"""
Reduces the RNN-T problem (the simple case, where the joiner network is
just addition) to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion().
This version allows you to make the loss-function one of the form::
lm_only_scale * lm_probs +
am_only_scale * am_probs +
(1-lm_only_scale-am_only_scale) * combined_probs
where lm_probs and am_probs are the probabilities given the lm and acoustic
model independently.
This function is called from
:func:`rnnt_loss_smoothed`, but may be useful for other purposes.
Args:
lm:
Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape::
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am:
Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape::
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols:
A LongTensor of shape [B][S], containing the symbols at each position
of the sequence.
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
lm_only_scale:
the scale on the "LM-only" part of the loss.
am_only_scale:
the scale on the "AM-only" part of the loss, for which we use
an "averaged" LM (averaged over all histories, so effectively unigram).
boundary:
an optional LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame]; it is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
Returns:
(px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences
of length s and t respectively. px[b][s][t] represents the
probability of extending the subsequences of length (s,t) by one in
the s direction, given the particular symbol, and py[b][s][t]
represents the probability of extending the subsequences of length
(s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
we cannot emit any symbols. This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert lm.ndim == 3
assert am.ndim == 3
assert lm.shape[0] == am.shape[0]
assert lm.shape[2] == am.shape[2]
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
# Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that
# all of the logprobs here are properly normalized. We test that
# this code is invariant to adding constants in the appropriate ways.
# subtracting am_max and lm_max is to ensure the probs are in a good range
# to do exp() without causing underflow or overflow.
am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1]
lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1]
am_probs = (am - am_max).exp() # [B][T][C]
lm_probs = (lm - lm_max).exp() # [B][S+1][C]
# normalizers: [B][S+1][T]
normalizers = (
torch.matmul(lm_probs, am_probs.transpose(1, 2))
+ torch.finfo(lm_probs.dtype).tiny
).log()
# normalizer per frame, if we take only the LM probs by themselves
lmonly_normalizers = lm_probs.sum(
dim=2, keepdim=True
) # lmonly_normalizers: [B][S+1][1]
unigram_lm = (
torch.mean(lm_probs / lmonly_normalizers, dim=(0, 1), keepdim=True)
+ torch.finfo(lm_probs.dtype).tiny
) # [1][1][C]
amonly_normalizers = (
torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C))
.reshape(B, T, 1)
.log()
+ am_max
) # [B][T][1]
amonly_normalizers = amonly_normalizers.transpose(1, 2) # [B][1][T]
unigram_lm = unigram_lm.log()
lmonly_normalizers = (
lmonly_normalizers.log() + lm_max
) # [B][S+1][1], log-normalizer, used for LM-only part of prob.
# add lm_max and am_max to normalizers, to make it as if we had not
# subtracted am_max and lm_max above.
normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T]
# px is the probs of the actual symbols (not yet normalized)..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
if not modified:
px_am = torch.cat(
(
px_am,
torch.full(
(B, S, 1),
float("-inf"),
device=px_am.device,
dtype=px_am.dtype,
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
px_lm = torch.gather(
lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]
px_lm_unigram = torch.gather(
unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]
px = px_am + px_lm # [B][S][T+1] if not modified, [B][S][T] if modified
px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1] or [B][S][T]
px_amonly = (
px_am + px_lm_unigram
) # [B][S][T+1] if !modified; [B][S][T] if modified.
px_amonly[:, :, :T] -= amonly_normalizers
px_lmonly = px_lm - lmonly_normalizers[:, :S, :]
# py is the probs of termination symbols, of shape [B][S+1][T]
py_am = am[:, :, termination_symbol].unsqueeze(1) # [B][1][T]
py_lm = lm[:, :, termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
py_lm_unigram = unigram_lm[0][0][termination_symbol] # scalar, normalized..
py_amonly = py_am + py_lm_unigram - amonly_normalizers # [B][S+1][T]
py_lmonly = py_lm - lmonly_normalizers # [B][S+1][T]
combined_scale = 1.0 - lm_only_scale - am_only_scale
# We need to avoid exact zeros in the scales because otherwise multiplying
# -inf by zero generates nan.
if lm_only_scale == 0.0:
lm_only_scale = 1.0e-20
if am_only_scale == 0.0:
am_only_scale = 1.0e-20
px_interp = (
px * combined_scale
+ px_lmonly * lm_only_scale
+ px_amonly * am_only_scale
)
py_interp = (
py * combined_scale
+ py_lmonly * lm_only_scale
+ py_amonly * am_only_scale
)
if not modified:
px_interp = fix_for_boundary(px_interp, boundary)
return (px_interp, py_interp)
def rnnt_loss_smoothed(
lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
modified: bool = False,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""A simple case of the RNN-T loss, where the 'joiner' network is just
addition.
Args:
lm:
language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
These are assumed to be well-normalized, in the sense that we could
use them as probabilities separately from the am scores.
am:
acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols:
the symbol sequences, a LongTensor of shape [B][S], and elements in
{0..C-1}.
termination_symbol:
the termination symbol, with 0 <= termination_symbol < C
lm_only_scale:
the scale on the "LM-only" part of the loss.
am_only_scale:
the scale on the "AM-only" part of the loss, for which we use
an "averaged" LM (averaged over all histories, so effectively unigram).
boundary:
a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
also be consumed, so at most 1 symbol can appear per frame.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`mean`: apply `torch.mean` over the batches.
`sum`: the output will be summed.
Default: `mean`
return_grad:
Whether to return the gradients of px and py. These gradients can be
interpreted as occupation probabilities; they are the output of backward
with a `fake gradient` that is the same as the gradient you'd get if you
did `torch.autograd.grad((-loss.sum()), [px, py])`, where the loss is
computed with reduction "none".
This is useful for implementing the pruned version of the RNN-T loss.
Returns:
If return_grad is False, returns a tensor of shape (B,) containing the
total RNN-T loss values for each element of the batch if reduction is
"none", otherwise a scalar with the reduction applied.
If return_grad is True, the grads of px and py, which are the output of
backward with a `fake gradient` (see above), will be returned too, and
the returned value will be a tuple like (loss, (px_grad, py_grad)).
"""
px, py = get_rnnt_logprobs_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=lm_only_scale,
am_only_scale=am_only_scale,
boundary=boundary,
modified=modified,
)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
if reduction == "none":
loss = -negated_loss
elif reduction == "mean":
loss = -torch.mean(negated_loss)
elif reduction == "sum":
loss = -torch.sum(negated_loss)
else:
assert (
False
), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
return (loss, scores_and_grads[1]) if return_grad else loss
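# Usage sketch (illustrative only; the sizes and scales below are arbitrary
# assumptions, not recommended settings):
#
#   B, S, T, C = 2, 3, 10, 5
#   lm = torch.randn(B, S + 1, C)
#   am = torch.randn(B, T, C)
#   symbols = torch.randint(1, C, (B, S))
#   loss, (px_grad, py_grad) = rnnt_loss_smoothed(
#       lm, am, symbols, termination_symbol=0,
#       lm_only_scale=0.25, am_only_scale=0.0, return_grad=True,
#   )
#   # px_grad/py_grad can then be fed to get_rnnt_prune_ranges().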
function(fast_rnnt_add_py_test source)
get_filename_component(name ${source} NAME_WE)
set(name "${name}_py")
add_test(NAME ${name}
COMMAND
"${PYTHON_EXECUTABLE}"
"${CMAKE_CURRENT_SOURCE_DIR}/${source}"
)
get_filename_component(fast_rnnt_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
set_property(TEST ${name}
PROPERTY ENVIRONMENT "PYTHONPATH=${fast_rnnt_path}:$<TARGET_FILE_DIR:_fast_rnnt>:$ENV{PYTHONPATH}"
)
endfunction()
# please sort the files in alphabetic order
set(py_test_files
mutual_information_test.py
rnnt_loss_test.py
)
foreach(source IN LISTS py_test_files)
fast_rnnt_add_py_test(${source})
endforeach()
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (authors: Daniel Povey,
# Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this single test, use
#
# ctest --verbose -R mutual_information_test_py
import random
import unittest
import fast_rnnt
import torch
# Caution: this will fail occasionally due to cutoffs not being quite large
# enough. As long as it passes most of the time, it's OK.
class TestMutualInformation(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device("cpu")]
if torch.cuda.is_available() and fast_rnnt.with_cuda():
cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
cls.devices.append(torch.device("cuda", 1))
cls.dtypes = [torch.float32, torch.float64]
def test_mutual_information_basic(self):
for _iter in range(10):
(B, S, T) = (
random.randint(1, 10),
random.randint(1, 16),
random.randint(1, 500),
)
random_px = random.random() < 0.2
random_py = random.random() < 0.2
random_boundary = random.random() < 0.7
big_px = random.random() < 0.2
big_py = random.random() < 0.2
modified = random.random() < 0.5
if modified and T < S:
T = S + random.randint(0, 30)
for dtype in self.dtypes:
for device in self.devices:
if random_boundary:
def get_boundary_row():
this_S = random.randint(
0, S
) # allow empty sequence
this_T = random.randint(
this_S if modified else 1, T
)
s_begin = random.randint(0, S - this_S)
t_begin = random.randint(0, T - this_T)
s_end = s_begin + this_S
t_end = t_begin + this_T
return [s_begin, t_begin, s_end, t_end]
if device == torch.device("cpu"):
boundary = torch.tensor(
[get_boundary_row() for _ in range(B)],
dtype=torch.int64,
device=device,
)
else:
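# Reuse the boundary built in the CPU iteration above
# (cls.devices[0] is the CPU device); just move it to this device.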
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly
# or not.
if random.random() < 0.5:
boundary = (
torch.tensor([0, 0, S, T], dtype=torch.int64)
.unsqueeze(0)
.expand(B, 4)
.to(device)
)
else:
boundary = None
if device == torch.device("cpu"):
if random_px:
# log of an odds ratio
px = torch.randn(
B, S, T + (0 if modified else 1), dtype=dtype
).to(device)
if S > 1 and not random_boundary and not modified:
px[:, :, -1:] = float("-inf")
else:
# log of an odds ratio
px = torch.zeros(
B, S, T + (0 if modified else 1), dtype=dtype
).to(device)
# px and py get exponentiated, and then multiplied
# together up to 32 times (BLOCK_SIZE in the CUDA code),
# so 15 is actually a big number that could lead to
# overflow.
if big_px:
px += 15.0
if random_py:
# log of an odds ratio
py = torch.randn(B, S + 1, T, dtype=dtype).to(
device
)
else:
# log of an odds ratio
py = torch.zeros(B, S + 1, T, dtype=dtype).to(
device
)
if big_py:
py += 15.0
else:
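# Reuse the px/py built in the CPU iteration, moving them to this
# device so the CPU and CUDA results can be compared below.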
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
m = fast_rnnt.mutual_information_recursion(px, py, boundary)
m2 = fast_rnnt.joint_mutual_information_recursion(
(px,), (py,), boundary
)
m3 = fast_rnnt.joint_mutual_information_recursion(
(px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary
)
# it is supposed to be identical only after
# summing over dim 0, corresponding to the
# sequence dim
m3 = m3.sum(dim=0)
assert torch.allclose(m, m2)
assert torch.allclose(m, m3)
# the loop this is in checks that the CPU and CUDA versions
# give the same derivative;
# by randomizing which of m, m2 or m3 we backprop, we also
# ensure that the joint version of the code gives the same
# derivative as the regular version
scale = 3
if random.random() < 0.5:
(m.sum() * scale).backward()
elif random.random() < 0.5:
(m2.sum() * scale).backward()
else:
(m3.sum() * scale).backward()
if device == torch.device("cpu"):
expected_px_grad = px.grad
expected_py_grad = py.grad
expected_m = m
assert torch.allclose(
px.grad,
expected_px_grad.to(device),
atol=1.0e-02,
rtol=1.0e-02,
)
assert torch.allclose(
py.grad,
expected_py_grad.to(device),
atol=1.0e-02,
rtol=1.0e-02,
)
assert torch.allclose(
m, expected_m.to(device), atol=1.0e-02, rtol=1.0e-02
)
def test_mutual_information_deriv(self):
for _iter in range(10):
(B, S, T) = (
random.randint(1, 100),
random.randint(1, 200),
random.randint(1, 200),
)
random_px = random.random() < 0.2
random_py = random.random() < 0.2
random_boundary = random.random() < 0.7
big_px = random.random() < 0.2
big_py = random.random() < 0.2
modified = random.random() < 0.5
if modified and T < S:
T = S + random.randint(0, 30)
for dtype in self.dtypes:
for device in self.devices:
if random_boundary:
def get_boundary_row():
this_S = random.randint(1, S)
this_T = random.randint(
this_S if modified else 1, T
)
s_begin = random.randint(0, S - this_S)
t_begin = random.randint(0, T - this_T)
s_end = s_begin + this_S
t_end = t_begin + this_T
return [s_begin, t_begin, s_end, t_end]
if device == torch.device("cpu"):
boundary = torch.tensor(
[get_boundary_row() for _ in range(B)],
dtype=torch.int64,
device=device,
)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly
# or not.
if random.random() < 0.5:
boundary = (
torch.tensor([0, 0, S, T], dtype=torch.int64)
.unsqueeze(0)
.expand(B, 4)
.to(device)
)
else:
boundary = None
T1 = T + (0 if modified else 1)
if device == torch.device("cpu"):
if random_px:
# log of an odds ratio
px = torch.randn(B, S, T1, dtype=dtype).to(device)
else:
# log of an odds ratio
px = torch.zeros(B, S, T1, dtype=dtype).to(device)
# px and py get exponentiated, and then multiplied
# together up to 32 times (BLOCK_SIZE in the CUDA code),
# so 15 is actually a big number that could lead to
# overflow.
if big_px:
px += 15.0
if random_py:
# log of an odds ratio
py = torch.randn(B, S + 1, T, dtype=dtype).to(
device
)
else:
# log of an odds ratio
py = torch.zeros(B, S + 1, T, dtype=dtype).to(
device
)
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
m = fast_rnnt.mutual_information_recursion(px, py, boundary)
m_grad = torch.randn(B, dtype=dtype, device=device)
m.backward(gradient=m_grad)
delta = 1.0e-04
delta_px = delta * torch.randn_like(px)
m2 = fast_rnnt.mutual_information_recursion(
px + delta_px, py, boundary
)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to("cpu")
predicted_delta = (delta_px * px.grad).sum().to("cpu")
atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
assert torch.allclose(
observed_delta, predicted_delta, atol=atol, rtol=rtol
)
delta_py = delta * torch.randn_like(py)
m2 = fast_rnnt.mutual_information_recursion(
px, py + delta_py, boundary
)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to("cpu")
predicted_delta = (delta_py * py.grad).sum().to("cpu")
assert torch.allclose(
observed_delta, predicted_delta, atol=atol, rtol=rtol
)
if __name__ == "__main__":
unittest.main()