Unverified Commit a283cddf authored by Daniel Povey, committed by GitHub

Merge pull request #4 from pkufool/fast_rnnt

Sync with k2 rnnt_loss
parents b5828e2b 182fe8de
include_directories(${CMAKE_SOURCE_DIR})
# it is located in fast_rnnt/cmake/transform.cmake
include(transform)
set(srcs
mutual_information_cpu.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
transform(OUTPUT_VARIABLE srcs SRCS ${srcs})
else()
list(APPEND srcs mutual_information_cuda.cu)
endif()
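# Note: when FT_WITH_CUDA is OFF, transform() is expected to map the .cu
# sources above to plain C++ sources (see fast_rnnt/cmake/transform.cmake) so
# they can be built without a CUDA toolchain; when it is ON, the CUDA-only
# source mutual_information_cuda.cu is appended instead.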
add_library(mutual_information_core ${srcs})
target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
# for <torch/extension.h>
target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# see https://github.com/NVIDIA/thrust/issues/1401
# and https://github.com/k2-fsa/k2/pull/917
target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
target_link_libraries(mutual_information_core PUBLIC cub)
endif()
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang, Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_DEVICE_GUARD_H_
#define FAST_RNNT_CSRC_DEVICE_GUARD_H_
#include "torch/script.h"
// This file is modified from
// https://github.com/k2-fsa/k2/blob/master/k2/csrc/device_guard.h
namespace fast_rnnt {
// DeviceGuard is an RAII class. Its sole purpose is to restore
// the previous default cuda device if a CUDA context changes the
// current default cuda device.
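//
// A minimal usage sketch (the tensor below is illustrative, not part of
// this file):
//
//   torch::Tensor t = torch::zeros({3}, torch::Device(torch::kCUDA, 1));
//   {
//     DeviceGuard guard(t.device());  // switch to CUDA device 1 if needed
//     // ... run code that assumes the current device is t's device ...
//   }  // on scope exit the previously active device is restored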
class DeviceGuard {
public:
explicit DeviceGuard(torch::Device device) {
if (device.type() == torch::kCUDA) {
old_device_ = GetDevice();
new_device_ = device.index();
if (old_device_ != new_device_)
SetDevice(new_device_);
}
// else do nothing
}
explicit DeviceGuard(int32_t new_device) : new_device_(new_device) {
if (new_device != -1) {
old_device_ = GetDevice();
if (old_device_ != new_device)
SetDevice(new_device);
}
}
~DeviceGuard() {
if (old_device_ != -1 && old_device_ != new_device_) {
// restore the previous device
SetDevice(old_device_);
}
// else it was either a CPU context or the device IDs
// were the same
}
DeviceGuard(const DeviceGuard &) = delete;
DeviceGuard &operator=(const DeviceGuard &) = delete;
DeviceGuard(DeviceGuard &&) = delete;
DeviceGuard &operator=(DeviceGuard &&) = delete;
private:
static int32_t GetDevice() {
#ifdef FT_WITH_CUDA
int32_t device;
auto s = cudaGetDevice(&device);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
return device;
#else
return -1;
#endif
}
static void SetDevice(int32_t device) {
#ifdef FT_WITH_CUDA
auto s = cudaSetDevice(device);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
return;
#endif
}
private:
int32_t old_device_ = -1;
int32_t new_device_ = -1;
};
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_DEVICE_GUARD_H_
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#include <cmath>
#include <vector>
#include "torch/extension.h"
#ifdef __CUDA_ARCH__
#define FT_CUDA_HOSTDEV __host__ __device__
#else
#define FT_CUDA_HOSTDEV
#endif
namespace fast_rnnt {
FT_CUDA_HOSTDEV inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
// returns log(exp(x) + exp(y)).
FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
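// For example (illustrative):
//   LogAdd(std::log(0.25f), std::log(0.75f)) ~= std::log(1.0f) == 0.0f
//   LogAdd(-std::numeric_limits<float>::infinity(), 0.0f) == 0.0f
// i.e. adding a probability of zero leaves the other operand unchanged.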
/*
Forward of mutual_information. See also comment of `mutual_information`
in ../python/fast_rnnt/mutual_information.py. This is the core recursion
in the sequence-to-sequence mutual information computation.
@param px Tensor of shape [B][S][T + 1] if not modified, [B][S][T] if
modified. `modified` can be worked out from this. In the not-modified case,
it can be thought of as the log-odds ratio of generating the next x in
the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of
lengths (s, t), divided by the prior probability of generating x_s.
(See mutual_information.py for more info).
@param py The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
@param p This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, for the
b'th pair of sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
if not modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
... treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
@param boundary If set, a tensor of shape [B][4] of type int64_t, which
contains, for each batch element b,
boundary[b] = [s_begin, t_begin, s_end, t_end],
i.e. the beginning and end (one-past-the-last)
of the x and y sequences that we should process.
Alternatively, may be a tensor of shape [0][0] and type
int64_t; the elements will default to (0, 0, S, T).
@return A tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
(For the CUDA kernel, the block-dim and grid-dim must both be 1-dimensional,
and the block-dim must be at least 128.)
*/
torch::Tensor MutualInformationCpu(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
torch::Tensor MutualInformationCuda(
torch::Tensor px, // [B][S][T+1] if !modified, [B][S][T] if modified.
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
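/*
  A tiny worked instance of the recursion above (B = 1, S = 1, T = 1,
  not-modified case, dropping the batch index b; purely illustrative):
    p[0,0] = 0
    p[1,0] = p[0,0] + px[0,0]            // only an x-step is possible
    p[0,1] = p[0,0] + py[0,0]            // only a y-step is possible
    p[1,1] = LogAdd(p[0,1] + px[0,1],    // last step is an x-step
                    p[1,0] + py[1,0])    // last step is a y-step
    ans    = p[1,1]
*/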
/*
backward of mutual_information; returns (grad_px, grad_py)
if overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad);
std::vector<torch::Tensor> MutualInformationBackwardCuda(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "fast_rnnt/csrc/mutual_information.h"
namespace fast_rnnt {
// forward of mutual_information. See """... """ comment of
// `mutual_information_recursion` in
// k2/python/k2/mutual_information.py for documentation of the
// behavior of this function.
// px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
// `modified` from this.
// py: of shape [B, S+1, T]
// boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end)
// defaulting to (0, 0, S, T).
// p: of shape (B, S+1, T+1)
// Computes the recursion:
// if !modified:
// p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
// p[b,s,t-1] + py[b,s,t-1])
// if modified:
// p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
// p[b,s,t-1] + py[b,s,t-1])
// .. treating out-of-range elements as -infinity and with special cases:
// p[b, s_begin, t_begin] = 0.0
//
// and this function returns a tensor of shape (B,) consisting of elements
// p[b, s_end, t_end]
torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu(),
"inputs must be CPU tensors");
bool modified = (px.size(2) == py.size(2));
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(),
py_a = py.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>();
auto boundary_a = boundary.accessor<int64_t, 2>();
auto ans_a = ans.accessor<scalar_t, 1>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
p_a[b][s_begin][t_begin] = 0.0;
if (modified) {
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] = -std::numeric_limits<scalar_t>::infinity();
} else {
// note: t_offset = 0 so don't need t_begin + t_offset below.
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] =
p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
}
for (int t = t_begin + 1; t <= t_end; ++t)
p_a[b][s_begin][t] =
p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
for (int s = s_begin + 1; s <= s_end; ++s) {
scalar_t p_s_t1 = p_a[b][s][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t) {
// The following statement is a small optimization of:
// p_a[b][s][t] = LogAdd(
// p_a[b][s - 1][t + t_offset] + px_a[b][s -1][t + t_offset],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
p_s_t1 + py_a[b][s][t - 1]);
}
}
ans_a[b] = p_a[b][s_end][t_end];
}
}));
return ans;
}
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p, torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
bool modified = (px.size(2) == py.size(2));
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu() && ans_grad.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
bool has_boundary = opt_boundary.has_value();
int T1 = T + (modified ? 0 : 1);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts)
: torch::empty({B, S, T1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
: torch::empty({B, S + 1, T}, opts));
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>(),
p_grad_a = p_grad.accessor<scalar_t, 3>(),
px_grad_a = px_grad.accessor<scalar_t, 3>(),
py_grad_a = py_grad.accessor<scalar_t, 3>();
auto ans_grad_a = ans_grad.accessor<scalar_t, 1>();
auto boundary_a = boundary.accessor<int64_t, 2>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
// Backprop for: ans_a[b] = p_a[b][s_end][t_end];
p_grad_a[b][s_end][t_end] = ans_grad_a[b];
for (int s = s_end; s > s_begin; --s) {
for (int t = t_end; t > t_begin; --t) {
// The s,t indexes correspond to
// The statement we are backpropagating here is:
// p_a[b][s][t] = LogAdd(
// p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
scalar_t term1 = p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t];
if (total - total != 0)
total = 0;
scalar_t term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t];
scalar_t term1_grad, term2_grad;
if (term1_deriv - term1_deriv == 0.0) {
term1_grad = term1_deriv * grad;
term2_grad = term2_deriv * grad;
} else {
// could happen if total == -inf
term1_grad = term2_grad = 0.0;
}
px_grad_a[b][s - 1][t + t_offset] = term1_grad;
p_grad_a[b][s - 1][t + t_offset] = term1_grad;
py_grad_a[b][s][t - 1] = term2_grad;
p_grad_a[b][s][t - 1] += term2_grad;
}
}
for (int t = t_end; t > t_begin; --t) {
// Backprop for:
// p_a[b][s_begin][t] =
// p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
scalar_t this_p_grad = p_grad_a[b][s_begin][t];
p_grad_a[b][s_begin][t - 1] += this_p_grad;
py_grad_a[b][s_begin][t - 1] = this_p_grad;
}
if (!modified) {
for (int s = s_end; s > s_begin; --s) {
// Backprop for:
// p_a[b][s][t_begin] =
// p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
p_grad_a[b][s - 1][t_begin] += this_p_grad;
px_grad_a[b][s - 1][t_begin] = this_p_grad;
}
} // else these were all -infinity's and there is nothing to
// backprop.
// There is no backprop for:
// p_a[b][s_begin][t_begin] = 0.0;
// .. but we can use this for a check, that the grad at the beginning
// of the sequence is equal to the grad at the end of the sequence.
if (ans_grad_a[b] != 0.0) {
float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
if (fabs(grad_ratio - 1.0) > 0.01) {
std::cout
<< "Warning: mutual_information backprop: expected these "
<< "numbers to be the same:"
<< static_cast<float>(p_grad_a[b][s_begin][t_begin]) << " vs "
<< static_cast<float>(ans_grad_a[b]);
}
}
}
}));
return std::vector<torch::Tensor>({px_grad, py_grad});
}
} // namespace fast_rnnt
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/utils.h"
namespace fast_rnnt {
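// Replaces src[i], in place, with min(src[i], src[i+1], ..., src[n-1]),
// i.e. a suffix minimum, which makes the sequence non-decreasing.
// For example (illustrative):
//   [0, 2, 1, 3, 6, 5, 4]  ->  [0, 1, 1, 3, 4, 4, 4]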
void MonotonicLowerBound(torch::Tensor &src) {
TORCH_CHECK(src.dim() == 1, "Only one-dimensional tensors are supported");
TORCH_CHECK(src.scalar_type() == torch::kLong, "Only LongTensor is supported");
TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
int32_t dim = src.numel();
if (src.device().type() == torch::kCPU) {
int64_t min_value = std::numeric_limits<int64_t>::max();
int64_t *src_data = src.data_ptr<int64_t>();
for (int32_t i = dim - 1; i >= 0; --i) {
min_value = std::min(src_data[i], min_value);
src[i] = min_value;
}
} else {
#ifdef FT_WITH_CUDA
TORCH_CHECK(src.device().is_cuda());
internal::MinOp<int64_t> min_op;
auto src_data = src.data_ptr<int64_t>();
internal::ConstReversedPtr<int64_t> src_ptr =
internal::ConstReversedPtr<int64_t>(src_data, dim);
internal::ReversedPtr<int64_t> dest_ptr =
internal::ReversedPtr<int64_t>(src_data, dim);
// The first time is to determine temporary device storage requirements.
std::size_t temp_storage_bytes = 0;
auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
torch::dtype(torch::kInt8).device(src.device()));
int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif // FT_WITH_CUDA
}
}
} // namespace fast_rnnt
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_
#include "torch/extension.h"
#ifdef FT_WITH_CUDA
#include "cub/cub.cuh" // NOLINT
namespace fast_rnnt {
namespace internal {
template <typename T> struct MinOp {
__host__ __device__ __forceinline__ T operator()(const T &a,
const T &b) const {
return (a > b) ? b : a;
}
};
// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
const T *data;
// data points to the last element now
explicit ConstReversedPtr(const T *data, int32_t size)
: data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
return data[-i];
}
__host__ __device__ __forceinline__ ConstReversedPtr
operator+(int32_t n) const {
ConstReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ const T &operator*() const {
return *data;
}
};
template <typename T> struct ReversedPtr {
T *data;
// data points to the last element now
explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ T &operator[](int32_t i) {
return data[-i];
}
__host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
ReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ T &operator*() { return *data; }
};
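// Net effect: calling cub::DeviceScan::InclusiveScan with MinOp through a
// ConstReversedPtr input and a ReversedPtr output performs an inclusive
// min-scan from the end of the array, i.e. a suffix minimum, in place when
// both wrap the same buffer (as MonotonicLowerBound in utils.cu does).
// For example (illustrative): [3, 1, 4, 1, 5] -> [1, 1, 1, 1, 5].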
} // namespace internal
} // namespace fast_rnnt
namespace std {
// value_type is required by cub::DeviceScan::InclusiveScan
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
typedef T value_type;
};
} // namespace std
#endif // FT_WITH_CUDA
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_UTILS_H_
add_subdirectory(csrc)
add_subdirectory(tests)
include_directories(${CMAKE_SOURCE_DIR})
include(transform)
# please keep the list sorted
set(fast_rnnt_srcs
fast_rnnt.cu
mutual_information.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
transform(OUTPUT_VARIABLE fast_rnnt_srcs SRCS ${fast_rnnt_srcs})
endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
if(UNIX AND NOT APPLE)
target_link_libraries(_fast_rnnt
PRIVATE
${PYTHON_LIBRARY}
${TORCH_DIR}/lib/libtorch_python.so
)
endif()
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/python/csrc/fast_rnnt.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/utils.h"
PYBIND11_MODULE(_fast_rnnt, m) {
m.doc() = "Python wrapper for fast_rnnt.";
fast_rnnt::PybindMutualInformation(m);
fast_rnnt::PybindUtils(m);
}
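// From Python, the module registered above can be used directly
// (illustrative; assumes the built extension `_fast_rnnt` is importable):
//
//   import _fast_rnnt
//   print(_fast_rnnt.with_cuda())  # True if compiled with FT_WITH_CUDA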
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
#define FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
#include "pybind11/pybind11.h"
namespace py = pybind11;
#endif // FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
namespace fast_rnnt {
void PybindMutualInformation(py::module &m) {
m.def(
"mutual_information_forward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p) -> torch::Tensor {
fast_rnnt::DeviceGuard guard(px.device());
if (px.device().is_cpu()) {
return MutualInformationCpu(px, py, boundary, p);
} else {
#ifdef FT_WITH_CUDA
return MutualInformationCuda(px, py, boundary, p);
#else
TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
"that you compiled the code with -DFT_WITH_CUDA=ON.");
return torch::Tensor();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"));
m.def(
"mutual_information_backward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary, torch::Tensor p,
torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
fast_rnnt::DeviceGuard guard(px.device());
if (px.device().is_cpu()) {
return MutualInformationBackwardCpu(px, py, boundary, p, ans_grad);
} else {
#ifdef FT_WITH_CUDA
return MutualInformationBackwardCuda(px, py, boundary, p, ans_grad,
true);
#else
TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
"that you compiled the code with -DFT_WITH_CUDA=ON.");
return std::vector<torch::Tensor>();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad"));
}
} // namespace fast_rnnt
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindMutualInformation(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"
namespace fast_rnnt {
void PybindUtils(py::module &m) {
m.def(
"monotonic_lower_bound_",
[](torch::Tensor &src) -> void {
DeviceGuard guard(src.device());
if (src.dim() == 1) {
MonotonicLowerBound(src);
} else if (src.dim() == 2) {
int32_t dim0 = src.sizes()[0];
for (int32_t i = 0; i < dim0; ++i) {
auto sub = src.index({i});
MonotonicLowerBound(sub);
}
} else {
TORCH_CHECK(false,
"Only 1-D and 2-D tensors are supported");
}
},
py::arg("src"));
m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
return true;
#else
return false;
#endif
});
}
} // namespace fast_rnnt
/**
* @brief python wrappers for utils.h
*
* @copyright
* Copyright 2022 Xiaomi Corp. (author: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindUtils(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
from _fast_rnnt import monotonic_lower_bound_
from _fast_rnnt import with_cuda
from .mutual_information import mutual_information_recursion
from .mutual_information import joint_mutual_information_recursion
from .rnnt_loss import do_rnnt_pruning
from .rnnt_loss import get_rnnt_logprobs
from .rnnt_loss import get_rnnt_logprobs_joint
from .rnnt_loss import get_rnnt_logprobs_pruned
from .rnnt_loss import get_rnnt_logprobs_smoothed
from .rnnt_loss import get_rnnt_prune_ranges
from .rnnt_loss import rnnt_loss
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed
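# Example usage of the low-level utilities imported above (illustrative;
# assumes the package is importable as `fast_rnnt`):
#
#   import torch
#   import fast_rnnt
#
#   x = torch.tensor([3, 1, 4, 1, 5], dtype=torch.int64)
#   fast_rnnt.monotonic_lower_bound_(x)  # in-place suffix minimum
#   # x is now tensor([1, 1, 1, 1, 5])
#
#   fast_rnnt.with_cuda()  # True if the extension was built with FT_WITH_CUDA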
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import _fast_rnnt
from torch import Tensor
from typing import Tuple, Optional, Sequence, Union, List
class MutualInformationRecursionFunction(torch.autograd.Function):
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired.
"""
@staticmethod
def forward(
ctx,
px: torch.Tensor,
py: torch.Tensor,
pxy_grads: List[Optional[torch.Tensor]],
boundary: Optional[torch.Tensor] = None,
return_grad: bool = False,
) -> torch.Tensor:
"""
Computing mutual information between two sequences of real vectors.
Args:
px:
A torch.Tensor of some floating point type, with shape
``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the
length of the ``x`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the
length of the ``y`` sequence (including representations of
``EOS`` symbols but not ``BOS`` symbols). In the mutual
information application, ``px[b][s][t]`` would represent the
following log odds ratio; ignoring the b index on the right
to make the notation more
compact::
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in
practice be of the form::
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where ``N`` is the number of terms that the sum over ``t'``
included, which might include some or all of the other sequences as
well as this one.
Note:
we don't require ``px`` and ``py`` to be contiguous, but the
code assumes for optimization purposes that the ``T`` axis has
stride 1.
py:
A torch.Tensor of the same dtype as ``px``, with shape
``[B][S+1][T]``, representing::
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py``
are contiguous.
pxy_grads:
A list into which this function writes the gradients of ``px``
and ``py`` if ``return_grad == True``; it is left unchanged
if ``return_grad == False``.
See `this PR <https://github.com/k2-fsa/k2/pull/924>`_ for more
information about why this parameter was added.
Note:
the length of the list must be 2, where the first element
represents the grads of ``px`` and the second one represents
the grads of ``py``.
boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``,
with ``0 <= s_begin <= s_end <= S`` and ``0 <= t_begin <= t_end <= T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the
``x`` and ``y`` sequences respectively, and can be used if not
all sequences are
of the same length.
return_grad:
Whether to return the gradients of ``px`` and ``py``. These
gradients stand for occupation probabilities and are the output
of the backward pass with a ``fake gradient``; the ``fake gradient``
is the same as the gradient you'd get if you did
``torch.autograd.grad((scores.sum()), [px, py])``.
This is useful for implementing the pruned version of the RNN-T loss.
Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of
the mutual information between the b'th pair of sequences. This is
defined by the following recursion on ``p[b,s,t]`` (where ``p``
is of shape ``[B,S+1,T+1]``), representing a mutual information
between sub-sequences of lengths ``s`` and ``t``::
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
where we handle edge cases by treating quantities with negative
indexes as **-infinity**. The extension to cases where the
boundaries are specified should be obvious; it just works on
shorter sequences with offsets into ``px`` and ``py``.
"""
(B, S, T1) = px.shape
T = py.shape[-1]
assert T1 in [T, T + 1]
assert py.shape == (B, S + 1, T)
if boundary is not None:
assert boundary.shape == (B, 4)
# p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is
# the mutual information of the pair of subsequences of x and y that
# are of length s and t respectively. p[0][0] will be 0.0 and p[S][T]
# is the mutual information of the entire pair of sequences,
# i.e. of lengths S and T respectively.
# It is computed as follows (in C++ and CUDA):
# p[b,0,0] = 0.0
# p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
# p[b,s,t-1] + py[b,s,t-1])
# if s > 0 or t > 0,
# treating values with any -1 index as -infinity.
# .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _fast_rnnt.mutual_information_forward(px, py, boundary, p)
px_grad, py_grad = None, None
if return_grad or px.requires_grad or py.requires_grad:
ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
(px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
px, py, boundary, p, ans_grad)
ctx.save_for_backward(px_grad, py_grad)
assert len(pxy_grads) == 2
pxy_grads[0] = px_grad
pxy_grads[1] = py_grad
return ans
@staticmethod
def backward(
ctx, ans_grad: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
(px_grad, py_grad) = ctx.saved_tensors
(B,) = ans_grad.shape
ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1)
px_grad *= ans_grad
py_grad *= ans_grad
return (px_grad, py_grad, None, None, None)
def mutual_information_recursion(
px: Tensor,
py: Tensor,
boundary: Optional[Tensor] = None,
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""A recursion that is useful in computing mutual information between two
sequences of real vectors, but may be useful more generally in
sequence-to-sequence tasks where monotonic alignment between pairs of
sequences is desired. The definitions of the arguments are definitions that
would be used when computing this type of mutual information, but you can
also view them as arbitrary quantities and just make use of the formula
computed by this function.
Args:
px:
A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``,
where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence
(including representations of ``EOS`` symbols but not ``BOS`` symbols),
and ``T`` is the length of the ``y`` sequence (including representations
of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information
application, ``px[b][s][t]`` would represent the following log odds
ratio; ignoring the b index on the right to make the notation more
compact::
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an ``x`` value as opposed to a ``y`` value. In
practice it might be computed as ``a + b``, where ``a`` is the log
probability of choosing to extend the sequence of length ``(s,t)``
with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice
be of the form::
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where ``N`` is the number of terms that the sum over ``t'`` included,
which might include some or all of the other sequences as well as this
one.
Note:
we don't require ``px`` and ``py`` to be contiguous, but the
code assumes for optimization purposes that the ``T`` axis has
stride 1.
py:
A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``,
representing::
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat ``x`` and ``y`` differently; the only
difference is that for optimization purposes we assume the last axis
(the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are
contiguous.
boundary:
If supplied, a torch.LongTensor of shape ``[B][4]``, where each
row contains ``[s_begin, t_begin, s_end, t_end]``,
with ``0 <= s_begin <= s_end <= S`` and ``0 <= t_begin <= t_end <= T``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, T]`` will be assumed.
These are the beginning and one-past-the-last positions in the ``x`` and
``y`` sequences respectively, and can be used if not all sequences are
of the same length.
return_grad:
Whether to return the gradients of ``px`` and ``py``. These gradients
stand for occupation probabilities and are the output of the backward
pass with a ``fake gradient``; the ``fake gradient`` is the same as the
gradient you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``.
This is useful for implementing the pruned version of the RNN-T loss.
Returns:
Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual
information between the b'th pair of sequences. This is defined by
the following recursion on ``p[b,s,t]`` (where ``p`` is of shape
``[B,S+1,T+1]``), representing a mutual information between sub-sequences
of lengths ``s`` and ``t``::
p[b,0,0] = 0.0
if !modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
where we handle edge cases by treating quantities with negative indexes
as **-infinity**. The extension to cases where the boundaries are
specified should be obvious; it just works on shorter sequences with
offsets into ``px`` and ``py``.
"""
assert px.ndim == 3
B, S, T1 = px.shape
T = py.shape[-1]
assert px.shape[-1] in [T, T + 1] # if T, then "modified".
assert py.shape == (B, S + 1, T)
assert px.dtype == py.dtype
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following assertions are for efficiency
assert px.is_contiguous()
assert py.is_contiguous()
pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
px_grad, py_grad = pxy_grads
return (scores, (px_grad, py_grad)) if return_grad else scores
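# A minimal usage sketch of mutual_information_recursion (shapes and values
# below are illustrative assumptions):
#
#   B, S, T = 2, 3, 4
#   px = torch.randn(B, S, T + 1, requires_grad=True)
#   py = torch.randn(B, S + 1, T, requires_grad=True)
#   scores = mutual_information_recursion(px, py)  # shape (B,)
#   scores.sum().backward()                        # fills px.grad and py.grad
#
#   # With return_grad=True, the occupation probabilities (the "fake
#   # gradients" described above) are returned as well:
#   scores, (px_grad, py_grad) = mutual_information_recursion(
#       px.detach(), py.detach(), return_grad=True)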
def _inner_product(a: Tensor, b: Tensor) -> Tensor:
"""
Does inner product on the last dimension, with expected broadcasting,
i.e. equivalent to (a * b).sum(dim=-1)
without creating a large temporary.
"""
assert a.shape[-1] == b.shape[-1] # The last dim must be equal
a = a.unsqueeze(-2) # (..., 1, K)
b = b.unsqueeze(-1) # (..., K, 1)
c = torch.matmul(a, b) # (..., 1, 1)
return c.squeeze(-1).squeeze(-1)
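# For example (illustrative): with a of shape (1, B, K) and b of shape
# (N, B, K), broadcasting gives a result of shape (N, B); this is how it is
# applied to (px_grad, px_cat) in joint_mutual_information_recursion below.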
def joint_mutual_information_recursion(
px: Sequence[Tensor],
py: Sequence[Tensor],
boundary: Optional[Tensor] = None,
) -> Sequence[Tensor]:
"""A recursion that is useful for modifications of RNN-T and similar loss
functions, where the recursion probabilities have a number of terms and you
want them reported separately. See mutual_information_recursion() for more
documentation of the basic aspects of this.
Args:
px:
a sequence of Tensors, each of the same shape [B][S][T+1]
py:
a sequence of Tensor, each of the same shape [B][S+1][T],
the sequence must be the same length as px.
boundary:
optionally, a LongTensor of shape [B][4] containing rows
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S
and 0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
These are the beginning and one-past-the-last positions in the x
and y sequences respectively, and can be used if not all
sequences are of the same length.
Returns:
a Tensor of shape (len(px), B),
whose sum over dim 0 is the total log-prob of the recursion mentioned
below, per sequence. The first element of the sequence of length len(px)
is "special", in that it has an offset term reflecting the difference
between sum-of-log and log-of-sum; for more interpretable loss values,
the "main" part of your loss function should be first.
The recursion below applies if boundary == None, when it defaults
to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px
and py::
p = tensor of shape (B, S+1, T+1), containing -infinity
p[b,0,0] = 0.0
# do the following in loop over s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
p[b,s,t-1] + py_sum[b,s,t-1])
(if s > 0 or t > 0)
return p[:,S,T]
This function lets you implement the above recursion efficiently, except
that it gives you a breakdown of the contribution from all the elements of
px and py separately. As noted above, the first element of the
sequence is "special".
"""
N = len(px)
assert len(py) == N and N > 0
B, S, T1 = px[0].shape
T = py[0].shape[2]
assert T1 in [T, T + 1] # T if modified...
assert py[0].shape == (B, S + 1, T)
assert px[0].dtype == py[0].dtype
px_cat = torch.stack(
px, dim=0
) # (N, B, S, T+1) if !modified, (N, B, S, T) if modified.
py_cat = torch.stack(py, dim=0) # (N, B, S+1, T)
px_tot = px_cat.sum(dim=0) # (B, S, T+1)
py_tot = py_cat.sum(dim=0) # (B, S+1, T)
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
# The following assertions are for efficiency
assert px_tot.ndim == 3
assert py_tot.ndim == 3
p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)
# note, tot_probs is without grad.
tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p)
# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
(px_grad,
py_grad) = _fast_rnnt.mutual_information_backward(px_tot, py_tot, boundary, p,
ans_grad)
px_grad = px_grad.reshape(1, B, -1)
py_grad = py_grad.reshape(1, B, -1)
px_cat = px_cat.reshape(N, B, -1)
py_cat = py_cat.reshape(N, B, -1)
# get rid of -inf, would generate nan on product with 0
px_cat = px_cat.clamp(min=torch.finfo(px_cat.dtype).min)
py_cat = py_cat.clamp(min=torch.finfo(py_cat.dtype).min)
x_prods = _inner_product(px_grad, px_cat) # (N, B)
y_prods = _inner_product(py_grad, py_cat) # (N, B)
# If all the occupation counts were exactly 1.0 (i.e. no partial counts),
# "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
# will be more positive due to the difference between log-of-sum and
# sum-of-log
prods = x_prods + y_prods # (N, B)
with torch.no_grad():
offset = tot_probs - prods.sum(dim=0) # (B,)
prods[0] += offset
return prods # (N, B)
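# A minimal usage sketch of joint_mutual_information_recursion (shapes below
# are illustrative assumptions):
#
#   B, S, T, N = 2, 3, 4, 2
#   px = [torch.randn(B, S, T + 1) for _ in range(N)]
#   py = [torch.randn(B, S + 1, T) for _ in range(N)]
#   prods = joint_mutual_information_recursion(px, py)  # shape (N, B)
#   total = prods.sum(dim=0)  # total log-prob of the recursion, per sequence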