Unverified commit 2c2dc4b9 authored by Daniel Povey, committed by GitHub

Merge pull request #16 from pkufool/contiguous

Update rnnt_loss
parents c268c3d5 03b82cc4
@@ -46,7 +46,7 @@ message(STATUS "Enabled languages: ${languages}")
project(fast_rnnt ${languages})
set(FT_VERSION "1.0")
set(FT_VERSION "1.2")
set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
set(DEFAULT_BUILD_TYPE "Release")
@@ -133,10 +133,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include(pybind11)
include(torch)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
# CUB is included in CUDA toolkit 11.0 and above
include(cub)
endif()
if(FT_BUILD_TESTS)
enable_testing()
# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(download_cub)
if(CMAKE_VERSION VERSION_LESS 3.11)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
include(FetchContent)
set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")
FetchContent_Declare(cub
URL ${cub_URL}
URL_HASH ${cub_HASH}
)
FetchContent_GetProperties(cub)
if(NOT cub_POPULATED)
message(STATUS "Downloading cub")
FetchContent_Populate(cub)
endif()
message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}")
add_library(cub INTERFACE)
target_include_directories(cub INTERFACE ${cub_SOURCE_DIR})
endfunction()
download_cub()
@@ -5,7 +5,6 @@ include(transform)
set(srcs
mutual_information_cpu.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
@@ -18,10 +17,3 @@ add_library(mutual_information_core ${srcs})
target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
# for <torch/extension.h>
target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# see https://github.com/NVIDIA/thrust/issues/1401
# and https://github.com/k2-fsa/k2/pull/917
target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
target_link_libraries(mutual_information_core PUBLIC cub)
endif()
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/utils.h"
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src) {
TORCH_CHECK(src.dim() == 1, "Only one-dimensional tensors are supported");
TORCH_CHECK(src.scalar_type() == torch::kLong, "Only support LongTensor");
TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
int32_t dim = src.numel();
if (src.device().type() == torch::kCPU) {
int64_t min_value = std::numeric_limits<int64_t>::max();
int64_t *src_data = src.data_ptr<int64_t>();
for (int32_t i = dim - 1; i >= 0; --i) {
min_value = std::min(src_data[i], min_value);
src[i] = min_value;
}
} else {
#ifdef FT_WITH_CUDA
TORCH_CHECK(src.device().is_cuda());
internal::MinOp<int64_t> min_op;
auto src_data = src.data_ptr<int64_t>();
internal::ConstReversedPtr<int64_t> src_ptr =
internal::ConstReversedPtr<int64_t>(src_data, dim);
internal::ReversedPtr<int64_t> dest_ptr =
internal::ReversedPtr<int64_t>(src_data, dim);
// The first call is only to determine the temporary device storage requirements.
std::size_t temp_storage_bytes = 0;
auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
torch::dtype(torch::kInt8).device(src.device()));
int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif // FT_WITH_CUDA
}
}
} // namespace fast_rnnt
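Note: the MonotonicLowerBound removed above replaces each element with the minimum of the suffix starting at it (a reverse running minimum), i.e. the tightest non-decreasing sequence that lower-bounds the input. A minimal PyTorch sketch of the same semantics, for reference only; the function name and the flip/cummin formulation are illustrative, not taken from this commit:

import torch

def monotonic_lower_bound_ref(src: torch.Tensor) -> torch.Tensor:
    # out[i] = min(src[i:]): an inclusive min-scan over the reversed
    # sequence, flipped back at the end.
    flipped = torch.flip(src, dims=[0])
    suffix_min, _ = torch.cummin(flipped, dim=0)
    return torch.flip(suffix_min, dims=[0])

s = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6])
print(monotonic_lower_bound_ref(s))  # tensor([1, 1, 1, 1, 2, 2, 2, 6])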
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_
#include "torch/extension.h"
#ifdef FT_WITH_CUDA
#include "cub/cub.cuh" // NOLINT
namespace fast_rnnt {
namespace internal {
template <typename T> struct MinOp {
__host__ __device__ __forceinline__ T operator()(const T &a,
const T &b) const {
return (a > b) ? b : a;
}
};
// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
const T *data;
// data points to the last element now
explicit ConstReversedPtr(const T *data, int32_t size)
: data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
return data[-i];
}
__host__ __device__ __forceinline__ ConstReversedPtr
operator+(int32_t n) const {
ConstReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ const T &operator*() const {
return *data;
}
};
template <typename T> struct ReversedPtr {
T *data;
// data points to the last element now
explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ T &operator[](int32_t i) {
return data[-i];
}
__host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
ReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ T &operator*() { return *data; }
};
} // namespace internal
} // namespace fast_rnnt
namespace std {
// value_type is required by cub::DeviceScan::InclusiveScan
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
typedef T value_type;
};
} // namespace std
#endif // FT_WITH_CUDA
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_UTILS_H_
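Note: the ConstReversedPtr/ReversedPtr wrappers above exist because cub::DeviceScan::InclusiveScan only scans front to back; handing it reversed views of the buffer turns the inclusive min-scan into the suffix minimum that MonotonicLowerBound needs. The same trick in plain Python, purely for intuition:

from itertools import accumulate

data = [3, 1, 4, 1, 5, 9, 2, 6]
# Inclusive min-scan over the reversed view...
rev_scan = list(accumulate(reversed(data), min))
# ...then reverse the result: the suffix minimum of the original data.
print(rev_scan[::-1])  # [1, 1, 1, 1, 2, 2, 2, 6]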
......@@ -6,7 +6,6 @@ include(transform)
set(fast_rnnt_srcs
fast_rnnt.cu
mutual_information.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
@@ -14,7 +13,6 @@ if(NOT FT_WITH_CUDA)
endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
@@ -20,11 +20,9 @@
#include "fast_rnnt/python/csrc/fast_rnnt.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/utils.h"
PYBIND11_MODULE(_fast_rnnt, m) {
m.doc() = "Python wrapper for Fast Rnnt.";
fast_rnnt::PybindMutualInformation(m);
fast_rnnt::PybindUtils(m);
}
@@ -65,5 +65,13 @@ void PybindMutualInformation(py::module &m) {
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad"));
m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
return true;
#else
return false;
#endif
});
}
} // namespace fast_rnnt
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"
namespace fast_rnnt {
void PybindUtils(py::module &m) {
m.def(
"monotonic_lower_bound_",
[](torch::Tensor &src) -> void {
DeviceGuard guard(src.device());
if (src.dim() == 1) {
MonotonicLowerBound(src);
} else if (src.dim() == 2) {
int32_t dim0 = src.sizes()[0];
for (int32_t i = 0; i < dim0; ++i) {
auto sub = src.index({i});
MonotonicLowerBound(sub);
}
} else {
TORCH_CHECK(false, "Only 1-D and 2-D tensors are supported");
}
},
py::arg("src"));
m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
return true;
#else
return false;
#endif
});
}
} // namespace fast_rnnt
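Note: before this commit the binding above was exposed as fast_rnnt.monotonic_lower_bound_ (see the import removed from the package below) and worked in place, row by row on 2-D input. A hypothetical usage sketch, assuming a build that still ships the binding:

import torch
import fast_rnnt

x = torch.tensor([[3, 1, 4, 1],
                  [5, 9, 2, 6]])       # must be int64 and contiguous
fast_rnnt.monotonic_lower_bound_(x)    # in place, one row at a time
print(x)                               # tensor([[1, 1, 1, 1],
                                       #         [2, 2, 2, 6]])
print(fast_rnnt.with_cuda())           # True iff built with -DFT_WITH_CUDA=ON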
/**
* @brief python wrappers for utils.h
*
* @copyright
* Copyright 2022 Xiaomi Corp. (author: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindUtils(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
from _fast_rnnt import monotonic_lower_bound_
from _fast_rnnt import with_cuda
from .mutual_information import mutual_information_recursion
@@ -285,9 +285,10 @@ def mutual_information_recursion(
for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following assertions are for efficiency
assert px.is_contiguous()
assert py.is_contiguous()
# The following statements are for efficiency
px, py = px.contiguous(), py.contiguous()
pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
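Note on the assert-to-contiguous change above: callers often hold px/py as transposed or sliced views, which are mathematically valid inputs but fail an is_contiguous() assert. Calling .contiguous() instead accepts them, copying only when the layout actually is non-contiguous. A small sketch:

import torch

p = torch.randn(4, 6)
px = p.t()                  # a transposed view
print(px.is_contiguous())   # False: the old assert would have fired here
px = px.contiguous()        # no-op if already contiguous, otherwise a copy
print(px.is_contiguous())   # True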
@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following statements are for efficiency
px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
# The following assertions are for efficiency
assert px_tot.ndim == 3
assert py_tot.ndim == 3
This diff is collapsed.
@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
assert px.shape == (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
m = fast_rnnt.mutual_information_recursion(
px=px, py=py, boundary=None
)
if device == torch.device("cpu"):
expected = -m
@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# lm: [B][S+1][C]
lm = lm_.to(device)
@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert (
px.shape == (B, S, T)
if rnnt_type != "regular"
else (B, S, T + 1)
)
assert px.shape == (B, S, T) if modified else (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(
@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss and not modified:
if self.has_torch_rnnt_loss and rnnt_type == "regular":
import torchaudio.functional
m = torchaudio.functional.rnnt_loss(
@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
torch_grad = torch.autograd.grad(torch_loss, logits2)
torch_grad = torch_grad[0]
assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)
assert torch.allclose(
fast_loss, torch_loss, atol=1e-2, rtol=1e-2
)
assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)
assert torch.allclose(
fast_grad, torch_grad, atol=1e-2, rtol=1e-2
)
def test_rnnt_loss_smoothed(self):
B = 1
@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
)
print(
f"Unpruned rnnt loss with modified {modified} : {fast_loss}"
)
print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {fast_loss}")
# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)
@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
s_range=r,
)
# (B, T, r, C)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
am=am, lm=lm, ranges=ranges
)
logits = pruned_am + pruned_lm
@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Pruning loss with range {r} : {pruned_loss}")
# Test sequences that have only a small number of symbols; in that case
# s_range can be greater than S, which raised errors (e.g. nan or inf loss)
# in previous versions.
@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
print(f"B = {B}, T = {T}, S = {S}, C = {C}")
for modified in [True, False]:
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(
f"Unpruned rnnt loss with modified {modified} : {loss}"
)
print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)
S0 = 2
if modified:
if rnnt_type != "regular":
S0 = 1
for r in range(S0, S + 2):
@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Pruned loss with range {r} : {pruned_loss}")
if __name__ == "__main__":
unittest.main()
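Note: the pruning tests above all follow the same four-step pipeline. Below is a condensed, self-contained sketch under stated assumptions: the toy sizes are made up, termination_symbol=0 is arbitrary, and the pieces hidden by the collapsed hunks (the lm=/am= keyword pair for rnnt_loss_simple and the fast_rnnt.get_rnnt_prune_ranges name) are assumptions based on the visible test code.

import torch
import fast_rnnt

B, T, S, C = 2, 10, 5, 20
am = torch.randn(B, T, C)                # per-frame acoustic logits
lm = torch.randn(B, S + 1, C)            # per-token language-model logits
symbols = torch.randint(1, C, (B, S))
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = S                       # seq_length, as in the tests
boundary[:, 3] = T                       # frames, as in the tests

# 1) cheap "simple" loss, keeping occupation-probability gradients
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    lm=lm, am=am, symbols=symbols, termination_symbol=0,
    boundary=boundary, rnnt_type="regular",
    return_grad=True, reduction="none",
)
# 2) derive pruning bounds from those gradients
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=3,
)
# 3) combine am and lm only inside the pruned ranges -> (B, T, s_range, C)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = pruned_am + pruned_lm
# 4) exact loss restricted to the pruned region
pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits, symbols=symbols, ranges=ranges, termination_symbol=0,
    boundary=boundary, rnnt_type="regular", reduction="none",
)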
@@ -36,11 +36,10 @@ class BuildExtension(build_ext):
system_make_args = os.environ.get("MAKEFLAGS", "")
if cmake_args == "":
cmake_args = "-DCMAKE_BUILD_TYPE=Release"
cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF"
if make_args == "" and system_make_args == "":
print("For fast compilation, run:")
print('export FT_MAKE_ARGS="-j"; python setup.py install')
make_args = " -j "
if "PYTHON_EXECUTABLE" not in cmake_args:
print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")