Unverified commit a283cddf authored by Daniel Povey, committed by GitHub

Merge pull request #4 from pkufool/fast_rnnt

Sync with k2 rnnt_loss
parents b5828e2b 182fe8de
include_directories(${CMAKE_SOURCE_DIR})
# it is located in fast_rnnt/cmake/transform.cmake
include(transform)
set(srcs
mutual_information_cpu.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
transform(OUTPUT_VARIABLE srcs SRCS ${srcs})
else()
list(APPEND srcs mutual_information_cuda.cu)
endif()
add_library(mutual_information_core ${srcs})
target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
# for <torch/extension.h>
target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# see https://github.com/NVIDIA/thrust/issues/1401
# and https://github.com/k2-fsa/k2/pull/917
target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
target_link_libraries(mutual_information_core PUBLIC cub)
endif()
/**
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang, Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_DEVICE_GUARD_H_
#define FAST_RNNT_CSRC_DEVICE_GUARD_H_
#include "torch/script.h"
// This file is modified from
// https://github.com/k2-fsa/k2/blob/master/k2/csrc/device_guard.h
namespace fast_rnnt {
// DeviceGuard is an RAII class. It sets the current default CUDA device to
// the given device on construction and restores the previously active CUDA
// device on destruction.
class DeviceGuard {
public:
explicit DeviceGuard(torch::Device device) {
if (device.type() == torch::kCUDA) {
old_device_ = GetDevice();
new_device_ = device.index();
if (old_device_ != new_device_)
SetDevice(new_device_);
}
// else do nothing
}
explicit DeviceGuard(int32_t new_device) : new_device_(new_device) {
if (new_device != -1) {
old_device_ = GetDevice();
if (old_device_ != new_device)
SetDevice(new_device);
}
}
~DeviceGuard() {
if (old_device_ != -1 && old_device_ != new_device_) {
// restore the previous device
SetDevice(old_device_);
}
// else it was either a CPU context or the device IDs
// were the same
}
DeviceGuard(const DeviceGuard &) = delete;
DeviceGuard &operator=(const DeviceGuard &) = delete;
DeviceGuard(DeviceGuard &&) = delete;
DeviceGuard &operator=(DeviceGuard &&) = delete;
private:
static int32_t GetDevice() {
#ifdef FT_WITH_CUDA
int32_t device;
auto s = cudaGetDevice(&device);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
return device;
#else
return -1;
#endif
}
static void SetDevice(int32_t device) {
#ifdef FT_WITH_CUDA
auto s = cudaSetDevice(device);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
return;
#endif
}
private:
int32_t old_device_ = -1;
int32_t new_device_ = -1;
};
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_DEVICE_GUARD_H_
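For readers coming from the Python side, torch.cuda.device plays the same role that DeviceGuard plays here: it makes a given CUDA device current for a scope and restores the previous one when the scope exits. A minimal sketch of the analogous pattern (illustrative only, not part of this commit; the function name is made up):

# Illustrative sketch, not part of this commit: torch.cuda.device is the
# Python-level analogue of the C++ DeviceGuard above.
import torch

def scale_on_own_device(x: torch.Tensor) -> torch.Tensor:
    if x.is_cuda:
        # Roughly equivalent to `DeviceGuard guard(x.device());` in C++.
        with torch.cuda.device(x.device):
            return x * 2.0
    return x * 2.0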
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#include <cmath>
#include <vector>
#include "torch/extension.h"
#ifdef __CUDA_ARCH__
#define FT_CUDA_HOSTDEV __host__ __device__
#else
#define FT_CUDA_HOSTDEV
#endif
namespace fast_rnnt {
FT_CUDA_HOSTDEV inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
// returns log(exp(x) + exp(y)).
FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
/*
Forward of mutual_information. See also the comment of
`mutual_information_recursion` in ../python/fast_rnnt/mutual_information.py.
This is the core recursion
in the sequence-to-sequence mutual information computation.
@param px Tensor of shape [B][S][T + 1] if not modified, [B][S][T] if
modified. `modified` can be worked out from this. In the not-modified case,
it can be thought of as the log-odds ratio of generating the next x in
the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of
lengths (s, t), divided by the prior probability of generating x_s.
(See mutual_information.py for more info).
@param py The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
@param p This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
if not modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
... treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
@param boundary If set, a tensor of shape [B][4] of type int64_t; for each
batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end],
which are the beginning and end (i.e. one-past-the-last)
of the x and y sequences that we should process.
Alternatively, may be a tensor of shape [0][0] and type
int64_t; the elements will default to (0, 0, S, T).
@return A tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not specified,
and (boundary[b][2], boundary[b][3]) if it was.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
torch::Tensor MutualInformationCpu(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
torch::Tensor MutualInformationCuda(
torch::Tensor px, // [B][S][T+1] if !modified, [B][S][T] if modified.
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
/*
backward of mutual_information; returns (grad_px, grad_py).
If overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad);
std::vector<torch::Tensor> MutualInformationBackwardCuda(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
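To make the recursion documented above concrete, here is a naive Python/PyTorch reference sketch for the not-modified case with the default boundary (s_begin = t_begin = 0, s_end = S, t_end = T). It is illustrative only and is not how the library computes the value; torch.logaddexp plays the role of the LogAdd helper defined above.

# Naive reference for the recursion documented in this header (not-modified
# case, default boundary).  Illustrative only; the real implementation is
# the C++/CUDA code in this package.
import torch

def mutual_information_reference(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    """px: [B, S, T+1], py: [B, S+1, T]; returns ans of shape [B]."""
    B, S, T1 = px.shape
    T = T1 - 1
    neg_inf = torch.full((B,), float("-inf"), dtype=px.dtype)
    p = torch.full((B, S + 1, T + 1), float("-inf"), dtype=px.dtype)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            # Indexes of -1 are treated as -infinity, as described above.
            term_x = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term_y = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S, T]   # ans[b] = p[b][S][T]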
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "fast_rnnt/csrc/mutual_information.h"
namespace fast_rnnt {
// forward of mutual_information. See """... """ comment of
// `mutual_information_recursion` in
// k2/python/k2/mutual_information.py for documentation of the
// behavior of this function.
// px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
// `modified` from this.
// py: of shape [B, S+1, T]
// boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end)
// defaulting to (0, 0, S, T).
// p: of shape [B, S+1, T+1]
// Computes the recursion:
// if !modified:
// p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
// p[b,s,t-1] + py[b,s,t-1])
// if modified:
// p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
// p[b,s,t-1] + py[b,s,t-1])
// .. treating out-of-range elements as -infinity and with special cases:
// p[b, s_begin, t_begin] = 0.0
//
// and this function returns a tensor of shape (B,) consisting of elements
// p[b, s_end, t_end]
torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu(),
"inputs must be CPU tensors");
bool modified = (px.size(2) == py.size(2));
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(),
py_a = py.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>();
auto boundary_a = boundary.accessor<int64_t, 2>();
auto ans_a = ans.accessor<scalar_t, 1>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
p_a[b][s_begin][t_begin] = 0.0;
if (modified) {
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] = -std::numeric_limits<scalar_t>::infinity();
} else {
// note: t_offset = 0 so don't need t_begin + t_offset below.
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] =
p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
}
for (int t = t_begin + 1; t <= t_end; ++t)
p_a[b][s_begin][t] =
p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
for (int s = s_begin + 1; s <= s_end; ++s) {
scalar_t p_s_t1 = p_a[b][s][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t) {
// The following statement is a small optimization of:
// p_a[b][s][t] = LogAdd(
// p_a[b][s - 1][t + t_offset] + px_a[b][s -1][t + t_offset],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
p_s_t1 + py_a[b][s][t - 1]);
}
}
ans_a[b] = p_a[b][s_end][t_end];
}
}));
return ans;
}
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p, torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
bool modified = (px.size(2) == py.size(2));
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu() && ans_grad.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
bool has_boundary = opt_boundary.has_value();
int T1 = T + (modified ? 0 : 1);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts)
: torch::empty({B, S, T1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
: torch::empty({B, S + 1, T}, opts));
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>(),
p_grad_a = p_grad.accessor<scalar_t, 3>(),
px_grad_a = px_grad.accessor<scalar_t, 3>(),
py_grad_a = py_grad.accessor<scalar_t, 3>();
auto ans_grad_a = ans_grad.accessor<scalar_t, 1>();
auto boundary_a = boundary.accessor<int64_t, 2>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
// Backprop for: ans_a[b] = p_a[b][s_end][t_end];
p_grad_a[b][s_end][t_end] = ans_grad_a[b];
for (int s = s_end; s > s_begin; --s) {
for (int t = t_end; t > t_begin; --t) {
// The statement we are backpropagating here is:
// p_a[b][s][t] = LogAdd(
//    p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
//    p_a[b][s][t - 1] + py_a[b][s][t - 1]);
scalar_t term1 = p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t];
if (total - total != 0)
total = 0;
scalar_t term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t];
scalar_t term1_grad, term2_grad;
if (term1_deriv - term1_deriv == 0.0) {
term1_grad = term1_deriv * grad;
term2_grad = term2_deriv * grad;
} else {
// could happen if total == -inf
term1_grad = term2_grad = 0.0;
}
px_grad_a[b][s - 1][t + t_offset] = term1_grad;
p_grad_a[b][s - 1][t + t_offset] = term1_grad;
py_grad_a[b][s][t - 1] = term2_grad;
p_grad_a[b][s][t - 1] += term2_grad;
}
}
for (int t = t_end; t > t_begin; --t) {
// Backprop for:
// p_a[b][s_begin][t] =
// p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
scalar_t this_p_grad = p_grad_a[b][s_begin][t];
p_grad_a[b][s_begin][t - 1] += this_p_grad;
py_grad_a[b][s_begin][t - 1] = this_p_grad;
}
if (!modified) {
for (int s = s_end; s > s_begin; --s) {
// Backprop for:
// p_a[b][s][t_begin] =
// p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
p_grad_a[b][s - 1][t_begin] += this_p_grad;
px_grad_a[b][s - 1][t_begin] = this_p_grad;
}
} // else these were all -infinity's and there is nothing to
// backprop.
// There is no backprop for:
// p_a[b][s_begin][t_begin] = 0.0;
// .. but we can use this as a check that the grad at the beginning
// of the sequence equals the grad at the end of the sequence.
if (ans_grad_a[b] != 0.0) {
float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
if (fabs(grad_ratio - 1.0) > 0.01) {
std::cout
<< "Warning: mutual_information backprop: expected these "
<< "numbers to be the same:"
<< static_cast<float>(p_grad_a[b][s_begin][t_begin]) << " vs "
<< static_cast<float>(ans_grad_a[b]);
}
}
}
}));
return std::vector<torch::Tensor>({px_grad, py_grad});
}
} // namespace fast_rnnt
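A handy sanity check of the forward recursion implemented above: if px and py are all zeros, the log-add recursion simply counts the monotone lattice paths from (0, 0) to (S, T), so the answer is log C(S + T, S). A hedged sketch using the Python entry point imported further below (the tolerance is arbitrary):

# Sanity check for the forward recursion: with px == py == 0 the recursion
# counts lattice paths, so ans == log(binomial(S + T, S)).  Illustrative only.
import math
import torch
import fast_rnnt

B, S, T = 2, 5, 7
px = torch.zeros(B, S, T + 1)   # not-modified case
py = torch.zeros(B, S + 1, T)
ans = fast_rnnt.mutual_information_recursion(px, py, None)
expected = math.log(math.comb(S + T, S))
assert torch.allclose(ans, torch.full((B,), expected), atol=1e-3)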
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/utils.h"
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src) {
TORCH_CHECK(src.dim() == 1, "Only one-dimensional tensors are supported");
TORCH_CHECK(src.scalar_type() == torch::kLong, "Only LongTensor is supported");
TORCH_CHECK(src.is_contiguous(), "Expected src to be contiguous");
int32_t dim = src.numel();
if (src.device().type() == torch::kCPU) {
int64_t min_value = std::numeric_limits<int64_t>::max();
int64_t *src_data = src.data_ptr<int64_t>();
// src is contiguous (checked above), so write through the raw pointer
// rather than via per-element tensor indexing.
for (int32_t i = dim - 1; i >= 0; --i) {
min_value = std::min(src_data[i], min_value);
src_data[i] = min_value;
}
} else {
#ifdef FT_WITH_CUDA
TORCH_CHECK(src.device().is_cuda());
internal::MinOp<int64_t> min_op;
auto src_data = src.data_ptr<int64_t>();
internal::ConstReversedPtr<int64_t> src_ptr =
internal::ConstReversedPtr<int64_t>(src_data, dim);
internal::ReversedPtr<int64_t> dest_ptr =
internal::ReversedPtr<int64_t>(src_data, dim);
// The first call only determines the temporary device storage requirements.
std::size_t temp_storage_bytes = 0;
auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
torch::dtype(torch::kInt8).device(src.device()));
int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif // FT_WITH_CUDA
}
}
} // namespace fast_rnnt
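In other words, MonotonicLowerBound replaces src with its suffix minimum (an inclusive min-scan running from the end of the tensor towards the beginning), which makes the sequence monotonically non-decreasing by lowering elements. A small Python illustration of the same semantics; the flip/cummin/flip form is just an equivalent way of writing the reversed scan, not what the C++/CUDA code above does internally:

# Reference semantics of MonotonicLowerBound: suffix minimum of a 1-D tensor,
# i.e. reverse, take a running minimum, and reverse back.  Illustrative only.
import torch

def monotonic_lower_bound_reference(src: torch.Tensor) -> torch.Tensor:
    assert src.dim() == 1
    flipped = torch.flip(src, dims=[0])
    return torch.flip(torch.cummin(flipped, dim=0).values, dims=[0])

x = torch.tensor([2, 1, 3, 6, 5, 4], dtype=torch.int64)
print(monotonic_lower_bound_reference(x))   # tensor([1, 1, 3, 4, 4, 4])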
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_
#include "torch/extension.h"
#ifdef FT_WITH_CUDA
#include "cub/cub.cuh" // NOLINT
namespace fast_rnnt {
namespace internal {
template <typename T> struct MinOp {
__host__ __device__ __forceinline__ T operator()(const T &a,
const T &b) const {
return (a > b) ? b : a;
}
};
// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
const T *data;
// data points to the last element now
explicit ConstReversedPtr(const T *data, int32_t size)
: data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
return data[-i];
}
__host__ __device__ __forceinline__ ConstReversedPtr
operator+(int32_t n) const {
ConstReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ const T &operator*() const {
return *data;
}
};
template <typename T> struct ReversedPtr {
T *data;
// data points to the last element now
explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ T &operator[](int32_t i) {
return data[-i];
}
__host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
ReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ T &operator*() { return *data; }
};
} // namespace internal
} // namespace fast_rnnt
namespace std {
// value_type is required by cub::DeviceScan::InclusiveScan
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
typedef T value_type;
};
} // namespace std
#endif // FT_WITH_CUDA
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_UTILS_H_
add_subdirectory(csrc)
add_subdirectory(tests)
include_directories(${CMAKE_SOURCE_DIR})
include(transform)
# please keep the list sorted
set(fast_rnnt_srcs
fast_rnnt.cu
mutual_information.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
transform(OUTPUT_VARIABLE fast_rnnt_srcs SRCS ${fast_rnnt_srcs})
endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
if(UNIX AND NOT APPLE)
target_link_libraries(_fast_rnnt
PRIVATE
${PYTHON_LIBRARY}
${TORCH_DIR}/lib/libtorch_python.so
)
endif()
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/python/csrc/fast_rnnt.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/utils.h"
PYBIND11_MODULE(_fast_rnnt, m) {
m.doc() = "Python wrapper for Fast Rnnt.";
fast_rnnt::PybindMutualInformation(m);
fast_rnnt::PybindUtils(m);
}
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
#define FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
#include "pybind11/pybind11.h"
namespace py = pybind11;
#endif // FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
namespace fast_rnnt {
void PybindMutualInformation(py::module &m) {
m.def(
"mutual_information_forward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p) -> torch::Tensor {
fast_rnnt::DeviceGuard guard(px.device());
if (px.device().is_cpu()) {
return MutualInformationCpu(px, py, boundary, p);
} else {
#ifdef FT_WITH_CUDA
return MutualInformationCuda(px, py, boundary, p);
#else
TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
"that you compiled the code with K2_WITH_CUDA.");
return torch::Tensor();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"));
m.def(
"mutual_information_backward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary, torch::Tensor p,
torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
fast_rnnt::DeviceGuard guard(px.device());
if (px.device().is_cpu()) {
return MutualInformationBackwardCpu(px, py, boundary, p, ans_grad);
} else {
#ifdef FT_WITH_CUDA
return MutualInformationBackwardCuda(px, py, boundary, p, ans_grad,
true);
#else
TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
"that you compiled the code with K2_WITH_CUDA.");
return std::vector<torch::Tensor>();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad"));
}
} // namespace fast_rnnt
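The two bindings above are the forward and backward halves of one differentiable op. Below is a hedged sketch of how they would typically be tied together with torch.autograd.Function on the Python side; the real wrapper lives in fast_rnnt/python/fast_rnnt/mutual_information.py and may differ in details, so treat this only as an illustration of the calling convention (in particular, p is a caller-allocated output buffer of shape [B, S+1, T+1], as documented in the C++ header).

# Hedged sketch of wiring mutual_information_forward/backward into autograd.
# The actual wrapper is in fast_rnnt/python/fast_rnnt/mutual_information.py;
# this only illustrates the calling convention of the bindings above.
import torch
import _fast_rnnt

class _MutualInformation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, px, py, boundary):
        B, S, _ = px.shape
        T = py.shape[2]
        # p is an output buffer allocated by the caller, shape [B, S+1, T+1].
        p = torch.empty(B, S + 1, T + 1, dtype=px.dtype, device=px.device)
        ans = _fast_rnnt.mutual_information_forward(px, py, boundary, p)
        ctx.save_for_backward(px, py, p)
        ctx.boundary = boundary
        return ans

    @staticmethod
    def backward(ctx, ans_grad):
        px, py, p = ctx.saved_tensors
        # Clone ans_grad: the CUDA backward is called with
        # overwrite_ans_grad == true and may overwrite it (see above).
        px_grad, py_grad = _fast_rnnt.mutual_information_backward(
            px, py, ctx.boundary, p, ans_grad.clone())
        return px_grad, py_grad, None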
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindMutualInformation(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"
namespace fast_rnnt {
void PybindUtils(py::module &m) {
m.def(
"monotonic_lower_bound_",
[](torch::Tensor &src) -> void {
DeviceGuard guard(src.device());
if (src.dim() == 1) {
MonotonicLowerBound(src);
} else if (src.dim() == 2) {
int32_t dim0 = src.sizes()[0];
for (int32_t i = 0; i < dim0; ++i) {
auto sub = src.index({i});
MonotonicLowerBound(sub);
}
} else {
TORCH_CHECK(false,
"Only 1-D and 2-D tensors are supported");
}
},
py::arg("src"));
m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
return true;
#else
return false;
#endif
});
}
} // namespace fast_rnnt
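From Python, the two utilities bound above are exposed at the top level of the fast_rnnt package (see the imports below). A brief usage sketch; note that monotonic_lower_bound_ modifies its argument in place and, for a 2-D tensor, is applied row by row:

# Usage sketch for the utilities bound above.  Illustrative only.
import torch
import fast_rnnt

print(fast_rnnt.with_cuda())   # True iff built with -DFT_WITH_CUDA=ON

ranges = torch.tensor([[2, 1, 3, 6, 5, 4],
                       [0, 3, 2, 2, 5, 1]], dtype=torch.int64)
fast_rnnt.monotonic_lower_bound_(ranges)   # in place, row by row
print(ranges)
# tensor([[1, 1, 3, 4, 4, 4],
#         [0, 1, 1, 1, 1, 1]])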
/**
* @brief python wrappers for utils.h
*
* @copyright
* Copyright 2022 Xiaomi Corp. (author: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindUtils(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
from _fast_rnnt import monotonic_lower_bound_
from _fast_rnnt import with_cuda
from .mutual_information import mutual_information_recursion
from .mutual_information import joint_mutual_information_recursion
from .rnnt_loss import do_rnnt_pruning
from .rnnt_loss import get_rnnt_logprobs
from .rnnt_loss import get_rnnt_logprobs_joint
from .rnnt_loss import get_rnnt_logprobs_pruned
from .rnnt_loss import get_rnnt_logprobs_smoothed
from .rnnt_loss import get_rnnt_prune_ranges
from .rnnt_loss import rnnt_loss
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed
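The rnnt_loss imports above form the pruned RNN-T loss pipeline: compute the simple (trivial-joiner) loss with gradients, derive pruning ranges from those gradients, prune the encoder and prediction-network outputs, and evaluate the full loss only inside the pruned ranges. The sketch below is hypothetical: the function names come from the import list, but the keyword arguments, tensor shapes, and the joiner are assumptions and should be checked against rnnt_loss.py.

# Hypothetical sketch of the pruned RNN-T loss pipeline built from the
# functions imported above.  Keyword names, shapes, and the joiner are
# assumptions; verify against fast_rnnt/python/fast_rnnt/rnnt_loss.py.
import fast_rnnt

# Placeholders (not defined here):
#   am:       (B, T, C)      encoder output
#   lm:       (B, S + 1, C)  prediction-network output
#   symbols:  (B, S)         label sequences
#   boundary: (B, 4)         as described in the C++ headers above
#   blank_id: int            termination (blank) symbol
#   joiner:   callable producing logits from the pruned am/lm

simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    lm=lm, am=am, symbols=symbols, termination_symbol=blank_id,
    boundary=boundary, return_grad=True,
)
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=5,
)
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = joiner(am_pruned, lm_pruned)
pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits, symbols=symbols, ranges=ranges,
    termination_symbol=blank_id, boundary=boundary,
)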
function(fast_rnnt_add_py_test source)
get_filename_component(name ${source} NAME_WE)
set(name "${name}_py")
add_test(NAME ${name}
COMMAND
"${PYTHON_EXECUTABLE}"
"${CMAKE_CURRENT_SOURCE_DIR}/${source}"
)
get_filename_component(fast_rnnt_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
set_property(TEST ${name}
PROPERTY ENVIRONMENT "PYTHONPATH=${fast_rnnt_path}:$<TARGET_FILE_DIR:_fast_rnnt>:$ENV{PYTHONPATH}"
)
endfunction()
# please sort the files in alphabetic order
set(py_test_files
mutual_information_test.py
rnnt_loss_test.py
)
foreach(source IN LISTS py_test_files)
fast_rnnt_add_py_test(${source})
endforeach()
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (authors: Daniel Povey,
# Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this single test, use
#
# ctest --verbose -R mutual_information_test_py
import random
import unittest
import fast_rnnt
import torch
# Caution: this will fail occasionally due to cutoffs not being quite large
# enough. As long as it passes most of the time, it's OK.
class TestMutualInformation(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device("cpu")]
if torch.cuda.is_available() and fast_rnnt.with_cuda():
cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
cls.devices.append(torch.device("cuda", 1))
cls.dtypes = [torch.float32, torch.float64]
def test_mutual_information_basic(self):
for _iter in range(10):
(B, S, T) = (
random.randint(1, 10),
random.randint(1, 16),
random.randint(1, 500),
)
random_px = random.random() < 0.2
random_py = random.random() < 0.2
random_boundary = random.random() < 0.7
big_px = random.random() < 0.2
big_py = random.random() < 0.2
modified = random.random() < 0.5
if modified and T < S:
T = S + random.randint(0, 30)
for dtype in self.dtypes:
for device in self.devices:
if random_boundary:
def get_boundary_row():
this_S = random.randint(
0, S
) # allow empty sequence
this_T = random.randint(
this_S if modified else 1, T
)
s_begin = random.randint(0, S - this_S)
t_begin = random.randint(0, T - this_T)
s_end = s_begin + this_S
t_end = t_begin + this_T
return [s_begin, t_begin, s_end, t_end]
if device == torch.device("cpu"):
boundary = torch.tensor(
[get_boundary_row() for _ in range(B)],
dtype=torch.int64,
device=device,
)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly
# or not.
if random.random() < 0.5:
boundary = (
torch.tensor([0, 0, S, T], dtype=torch.int64)
.unsqueeze(0)
.expand(B, 4)
.to(device)
)
else:
boundary = None
if device == torch.device("cpu"):
if random_px:
# log of an odds ratio
px = torch.randn(
B, S, T + (0 if modified else 1), dtype=dtype
).to(device)
if S > 1 and not random_boundary and not modified:
px[:, :, -1:] = float("-inf")
else:
# log of an odds ratio
px = torch.zeros(
B, S, T + (0 if modified else 1), dtype=dtype
).to(device)
# px and py get exponentiated, and then multiplied
# together up to 32 times (BLOCK_SIZE in the CUDA code),
# so 15 is actually a big number that could lead to
# overflow.
if big_px:
px += 15.0
if random_py:
# log of an odds ratio
py = torch.randn(B, S + 1, T, dtype=dtype).to(
device
)
else:
# log of an odds ratio
py = torch.zeros(B, S + 1, T, dtype=dtype).to(
device
)
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
m = fast_rnnt.mutual_information_recursion(px, py, boundary)
m2 = fast_rnnt.joint_mutual_information_recursion(
(px,), (py,), boundary
)
m3 = fast_rnnt.joint_mutual_information_recursion(
(px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary
)
# m3 is supposed to be identical to m only after
# summing over dim 0, which corresponds to the
# sequence dimension
m3 = m3.sum(dim=0)
assert torch.allclose(m, m2)
assert torch.allclose(m, m3)
# the loop this is in checks that the CPU and CUDA versions
# give the same derivative;
# by randomizing which of m, m2 or m3 we backprop, we also
# ensure that the joint version of the code gives the same
# derivative as the regular version
scale = 3
if random.random() < 0.5:
(m.sum() * scale).backward()
elif random.random() < 0.5:
(m2.sum() * scale).backward()
else:
(m3.sum() * scale).backward()
if device == torch.device("cpu"):
expected_px_grad = px.grad
expected_py_grad = py.grad
expected_m = m
assert torch.allclose(
px.grad,
expected_px_grad.to(device),
atol=1.0e-02,
rtol=1.0e-02,
)
assert torch.allclose(
py.grad,
expected_py_grad.to(device),
atol=1.0e-02,
rtol=1.0e-02,
)
assert torch.allclose(
m, expected_m.to(device), atol=1.0e-02, rtol=1.0e-02
)
def test_mutual_information_deriv(self):
for _iter in range(10):
(B, S, T) = (
random.randint(1, 100),
random.randint(1, 200),
random.randint(1, 200),
)
random_px = random.random() < 0.2
random_py = random.random() < 0.2
random_boundary = random.random() < 0.7
big_px = random.random() < 0.2
big_py = random.random() < 0.2
modified = random.random() < 0.5
if modified and T < S:
T = S + random.randint(0, 30)
for dtype in self.dtypes:
for device in self.devices:
if random_boundary:
def get_boundary_row():
this_S = random.randint(1, S)
this_T = random.randint(
this_S if modified else 1, T
)
s_begin = random.randint(0, S - this_S)
t_begin = random.randint(0, T - this_T)
s_end = s_begin + this_S
t_end = t_begin + this_T
return [s_begin, t_begin, s_end, t_end]
if device == torch.device("cpu"):
boundary = torch.tensor(
[get_boundary_row() for _ in range(B)],
dtype=torch.int64,
device=device,
)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly
# or not.
if random.random() < 0.5:
boundary = (
torch.tensor([0, 0, S, T], dtype=torch.int64)
.unsqueeze(0)
.expand(B, 4)
.to(device)
)
else:
boundary = None
T1 = T + (0 if modified else 1)
if device == torch.device("cpu"):
if random_px:
# log of an odds ratio
px = torch.randn(B, S, T1, dtype=dtype).to(device)
else:
# log of an odds ratio
px = torch.zeros(B, S, T1, dtype=dtype).to(device)
# px and py get exponentiated, and then multiplied
# together up to 32 times (BLOCK_SIZE in the CUDA code),
# so 15 is actually a big number that could lead to
# overflow.
if big_px:
px += 15.0
if random_py:
# log of an odds ratio
py = torch.randn(B, S + 1, T, dtype=dtype).to(
device
)
else:
# log of an odds ratio
py = torch.zeros(B, S + 1, T, dtype=dtype).to(
device
)
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
m = fast_rnnt.mutual_information_recursion(px, py, boundary)
m_grad = torch.randn(B, dtype=dtype, device=device)
m.backward(gradient=m_grad)
delta = 1.0e-04
delta_px = delta * torch.randn_like(px)
m2 = fast_rnnt.mutual_information_recursion(
px + delta_px, py, boundary
)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to("cpu")
predicted_delta = (delta_px * px.grad).sum().to("cpu")
atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
assert torch.allclose(
observed_delta, predicted_delta, atol=atol, rtol=rtol
)
delta_py = delta * torch.randn_like(py)
m2 = fast_rnnt.mutual_information_recursion(
px, py + delta_py, boundary
)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to("cpu")
predicted_delta = (delta_py * py.grad).sum().to("cpu")
assert torch.allclose(
observed_delta, predicted_delta, atol=atol, rtol=rtol
)
if __name__ == "__main__":
unittest.main()