Commit b32f8a26 authored by pkufool's avatar pkufool
Browse files

Replace monotonic_lower_bound with efficient implementation; remove cub dependancy

parent b0ed23ef
...@@ -133,10 +133,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) ...@@ -133,10 +133,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include(pybind11) include(pybind11)
include(torch) include(torch)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
# CUB is included in CUDA toolkit 11.0 and above
include(cub)
endif()
if(FT_BUILD_TESTS) if(FT_BUILD_TESTS)
enable_testing() enable_testing()
......
# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(download_cub)
if(CMAKE_VERSION VERSION_LESS 3.11)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
include(FetchContent)
set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")
FetchContent_Declare(cub
URL ${cub_URL}
URL_HASH ${cub_HASH}
)
FetchContent_GetProperties(cub)
if(NOT cub)
message(STATUS "Downloading cub")
FetchContent_Populate(cub)
endif()
message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}")
add_library(cub INTERFACE)
target_include_directories(cub INTERFACE ${cub_SOURCE_DIR})
endfunction()
download_cub()
...@@ -5,7 +5,6 @@ include(transform) ...@@ -5,7 +5,6 @@ include(transform)
set(srcs set(srcs
mutual_information_cpu.cu mutual_information_cpu.cu
utils.cu
) )
if(NOT FT_WITH_CUDA) if(NOT FT_WITH_CUDA)
...@@ -18,10 +17,3 @@ add_library(mutual_information_core ${srcs}) ...@@ -18,10 +17,3 @@ add_library(mutual_information_core ${srcs})
target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES}) target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
# for <torch/extension.h> # for <torch/extension.h>
target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS}) target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# see https://github.com/NVIDIA/thrust/issues/1401
# and https://github.com/k2-fsa/k2/pull/917
target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
target_link_libraries(mutual_information_core PUBLIC cub)
endif()
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/utils.h"
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src) {
TORCH_CHECK(src.dim() == 1, "Only support one dimension tensor");
TORCH_CHECK(src.scalar_type() == torch::kLong, "Only support LongTensor");
TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
int32_t dim = src.numel();
if (src.device().type() == torch::kCPU) {
int64_t min_value = std::numeric_limits<int64_t>::max();
int64_t *src_data = src.data_ptr<int64_t>();
for (int32_t i = dim - 1; i >= 0; --i) {
min_value = std::min(src_data[i], min_value);
src[i] = min_value;
}
} else {
#ifdef FT_WITH_CUDA
TORCH_CHECK(src.device().is_cuda());
internal::MinOp<int64_t> min_op;
auto src_data = src.data_ptr<int64_t>();
internal::ConstReversedPtr<int64_t> src_ptr =
internal::ConstReversedPtr<int64_t>(src_data, dim);
internal::ReversedPtr<int64_t> dest_ptr =
internal::ReversedPtr<int64_t>(src_data, dim);
// The first time is to determine temporary device storage requirements.
std::size_t temp_storage_bytes = 0;
auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
torch::dtype(torch::kInt8).device(src.device()));
int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif // FT_WITH_CUDA
}
}
} // namespace fast_rnnt
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_
#include "torch/extension.h"
#ifdef FT_WITH_CUDA
#include "cub/cub.cuh" // NOLINT
namespace fast_rnnt {
namespace internal {
template <typename T> struct MinOp {
__host__ __device__ __forceinline__ T operator()(const T &a,
const T &b) const {
return (a > b) ? b : a;
}
};
// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
const T *data;
// data points to the last element now
explicit ConstReversedPtr(const T *data, int32_t size)
: data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
return data[-i];
}
__host__ __device__ __forceinline__ ConstReversedPtr
operator+(int32_t n) const {
ConstReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ const T &operator*() const {
return *data;
}
};
template <typename T> struct ReversedPtr {
T *data;
// data points to the last element now
explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ T &operator[](int32_t i) {
return data[-i];
}
__host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
ReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ T &operator*() { return *data; }
};
} // namespace internal
} // namespace fast_rnnt
namespace std {
// vaule_type is required by cub::DeviceScan::InclusiveSum
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
typedef T value_type;
};
} // namespace std
#endif // FT_WITH_CUDA
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_UTILS_H_
...@@ -6,7 +6,6 @@ include(transform) ...@@ -6,7 +6,6 @@ include(transform)
set(fast_rnnt_srcs set(fast_rnnt_srcs
fast_rnnt.cu fast_rnnt.cu
mutual_information.cu mutual_information.cu
utils.cu
) )
if(NOT FT_WITH_CUDA) if(NOT FT_WITH_CUDA)
...@@ -14,7 +13,6 @@ if(NOT FT_WITH_CUDA) ...@@ -14,7 +13,6 @@ if(NOT FT_WITH_CUDA)
endif() endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs}) pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core) target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
......
...@@ -20,11 +20,9 @@ ...@@ -20,11 +20,9 @@
#include "fast_rnnt/python/csrc/fast_rnnt.h" #include "fast_rnnt/python/csrc/fast_rnnt.h"
#include "fast_rnnt/python/csrc/mutual_information.h" #include "fast_rnnt/python/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/utils.h"
PYBIND11_MODULE(_fast_rnnt, m) { PYBIND11_MODULE(_fast_rnnt, m) {
m.doc() = "Python wrapper for Fast Rnnt."; m.doc() = "Python wrapper for Fast Rnnt.";
fast_rnnt::PybindMutualInformation(m); fast_rnnt::PybindMutualInformation(m);
fast_rnnt::PybindUtils(m);
} }
...@@ -65,5 +65,13 @@ void PybindMutualInformation(py::module &m) { ...@@ -65,5 +65,13 @@ void PybindMutualInformation(py::module &m) {
}, },
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"), py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad")); py::arg("ans_grad"));
m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
return true;
#else
return false;
#endif
});
} }
} // namespace fast_rnnt } // namespace fast_rnnt
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"
namespace fast_rnnt {
void PybindUtils(py::module &m) {
m.def(
"monotonic_lower_bound_",
[](torch::Tensor &src) -> void {
DeviceGuard guard(src.device());
if (src.dim() == 1) {
MonotonicLowerBound(src);
} else if (src.dim() == 2) {
int32_t dim0 = src.sizes()[0];
for (int32_t i = 0; i < dim0; ++i) {
auto sub = src.index({i});
MonotonicLowerBound(sub);
}
} else {
TORCH_CHECK(false,
"Only support 1 dimension and 2 dimensions tensor");
}
},
py::arg("src"));
m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
return true;
#else
return false;
#endif
});
}
} // namespace fast_rnnt
/**
* @brief python wrappers for utils.h
*
* @copyright
* Copyright 2022 Xiaomi Corp. (author: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindUtils(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
from _fast_rnnt import monotonic_lower_bound_
from _fast_rnnt import with_cuda from _fast_rnnt import with_cuda
from .mutual_information import mutual_information_recursion from .mutual_information import mutual_information_recursion
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
import os import os
import fast_rnnt
import torch import torch
from torch import Tensor from torch import Tensor
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
...@@ -521,6 +520,40 @@ def rnnt_loss( ...@@ -521,6 +520,40 @@ def rnnt_loss(
) )
def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
"""Compute a monotonically increasing lower bound of the tensor `x` on the
last dimension. The basic idea is: we traverse the tensor in reverse order,
and update current element with the following statement,
min_value = min(x[i], min_value)
x[i] = min_value
>>> import torch
>>> x = torch.tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
>>> _monotonic_lower_bound(x)
tensor([0, 1, 1, 3, 5, 5, 8], dtype=torch.int32)
>>> x
tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
>>> x = torch.randint(20, (3, 6), dtype=torch.int32)
>>> x
tensor([[12, 18, 5, 4, 18, 17],
[11, 14, 14, 3, 10, 4],
[19, 3, 8, 13, 7, 19]], dtype=torch.int32)
>>> _monotonic_lower_bound(x)
tensor([[ 4, 4, 4, 4, 17, 17],
[ 3, 3, 3, 3, 4, 4],
[ 3, 3, 7, 7, 7, 19]], dtype=torch.int32)
Args:
x:
The source tensor.
Returns:
Returns a tensor which is monotonic on the last dimension
(i.e. satisfiy `x[i] <= x[i+1]`).
"""
x = torch.flip(x, dims=(-1,))
x, _ = torch.cummin(x, dim=-1)
x = torch.flip(x, dims=(-1,))
return x
def _adjust_pruning_lower_bound( def _adjust_pruning_lower_bound(
s_begin: torch.Tensor, s_range: int s_begin: torch.Tensor, s_range: int
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -532,11 +565,10 @@ def _adjust_pruning_lower_bound( ...@@ -532,11 +565,10 @@ def _adjust_pruning_lower_bound(
- s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip - s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
any symbols. any symbols.
To make it monotonic increasing, we can use `monotonic_lower_bound` function To make it monotonic increasing, we can use `_monotonic_lower_bound` above,
in k2, which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea is: which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea is:
traverse the array in reverse order and update the elements by traverse the array in reverse order and update the elements by
`min_value = min(a_begin[i], min_value)`, the initial `min_value` is set to `min_value = min(a_begin[i], min_value)`.
`inf`.
The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range` The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range`
constraint is a little tricky. We first transform `s_begin` with constraint is a little tricky. We first transform `s_begin` with
...@@ -559,13 +591,13 @@ def _adjust_pruning_lower_bound( ...@@ -559,13 +591,13 @@ def _adjust_pruning_lower_bound(
""" """
# s_begin (B, T) # s_begin (B, T)
(B, T) = s_begin.shape (B, T) = s_begin.shape
fast_rnnt.monotonic_lower_bound_(s_begin) _monotonic_lower_bound(s_begin)
# do the magic transformation # do the magic transformation
s_begin = -( s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device) s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
) )
# make the transformed tensor to be non-decreasing # make the transformed tensor to be non-decreasing
fast_rnnt.monotonic_lower_bound_(s_begin) _monotonic_lower_bound(s_begin)
# make start symbol to be zero. # make start symbol to be zero.
s_begin = torch.where(s_begin < 0, 0, s_begin) s_begin = torch.where(s_begin < 0, 0, s_begin)
# do the magic transformation again to recover s_begin # do the magic transformation again to recover s_begin
......
...@@ -473,7 +473,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -473,7 +473,7 @@ class TestRnntLoss(unittest.TestCase):
rnnt_type=rnnt_type, rnnt_type=rnnt_type,
) )
print(f"Unpruned rnnt loss with {rnnt_loss} rnnt : {fast_loss}") print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {fast_loss}")
# pruning # pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple( simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
...@@ -578,7 +578,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -578,7 +578,7 @@ class TestRnntLoss(unittest.TestCase):
) )
S0 = 2 S0 = 2
if rnnt_type == "regular": if rnnt_type != "regular":
S0 = 1 S0 = 1
for r in range(S0, S + 2): for r in range(S0, S + 2):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment