Commit b0b548c9 authored by pkufool

Remove k2 dependency

parent 5a3e2552
......@@ -2,4 +2,6 @@
.idea
venv*
deploy*
__pycache__/*
**/__pycache__
**/build*
Testing*
......@@ -14,8 +14,13 @@ cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
set(languages CXX)
set(_FT_WITH_CUDA ON)
set(CMAKE_CXX_STANDARD 14)
# the following settings are modified from cub/CMakeLists.txt
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}")
find_program(FT_HAS_NVCC nvcc)
if(NOT FT_HAS_NVCC AND "$ENV{CUDACXX}" STREQUAL "")
......@@ -55,7 +60,7 @@ elseif(NOT CMAKE_BUILD_TYPE IN_LIST ALLOWABLE_BUILD_TYPES)
choose one from ${ALLOWABLE_BUILD_TYPES}")
endif()
option(FT_BUILD_TESTS "Whether to build tests or not" OFF)
option(FT_BUILD_TESTS "Whether to build tests or not" ON)
option(BUILD_SHARED_LIBS "Whether to build shared libs" ON)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
......@@ -128,6 +133,10 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include(pybind11)
include(torch)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
# CUB is included in CUDA toolkit 11.0 and above
include(cub)
endif()
if(FT_BUILD_TESTS)
enable_testing()
......
# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(download_cub)
if(CMAKE_VERSION VERSION_LESS 3.11)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
include(FetchContent)
set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")
FetchContent_Declare(cub
URL ${cub_URL}
URL_HASH ${cub_HASH}
)
FetchContent_GetProperties(cub)
if(NOT cub)
message(STATUS "Downloading cub")
FetchContent_Populate(cub)
endif()
message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}")
add_library(cub INTERFACE)
target_include_directories(cub INTERFACE ${cub_SOURCE_DIR})
endfunction()
download_cub()
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This function is used to copy foo.cu to foo.cc
# Usage:
#
# transform(OUTPUT_VARIABLE output_variable_name SRCS foo.cu bar.cu)
#
function(transform)
set(optional_args "") # there are no optional arguments
set(one_value_arg OUTPUT_VARIABLE)
set(multi_value_args SRCS)
cmake_parse_arguments(MY "${optional_args}" "${one_value_arg}" "${multi_value_args}" ${ARGN})
foreach(src IN LISTS MY_SRCS)
get_filename_component(src_name ${src} NAME_WE)
get_filename_component(src_dir ${src} DIRECTORY)
set(dst ${CMAKE_CURRENT_BINARY_DIR}/${src_dir}/${src_name}.cc)
list(APPEND ans ${dst})
message(STATUS "Renaming ${CMAKE_CURRENT_SOURCE_DIR}/${src} to ${dst}")
configure_file(${src} ${dst})
endforeach()
set(${MY_OUTPUT_VARIABLE} ${ans} PARENT_SCOPE)
endfunction()
include_directories(${CMAKE_SOURCE_DIR})
# it is located in fast_rnnt/cmake/transform.cmake
include(transform)
set(srcs
mutual_information_cpu.cc
mutual_information_cpu.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
transform(OUTPUT_VARIABLE srcs SRCS ${srcs})
else()
list(APPEND srcs mutual_information_cuda.cu)
endif()
add_library(mutual_information_core ${srcs})
target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
if(FT_WITH_CUDA)
set(cuda_srcs mutual_information_cuda.cu)
add_library(mutual_information_core_cuda ${cuda_srcs})
target_link_libraries(mutual_information_core_cuda PUBLIC ${TORCH_LIBRARIES})
# for <torch/extension.h>
target_include_directories(mutual_information_core_cuda PUBLIC ${PYTHON_INCLUDE_DIRS})
target_link_libraries(mutual_information_core PUBLIC mutual_information_core_cuda)
# for <torch/extension.h>
target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
# see https://github.com/NVIDIA/thrust/issues/1401
# and https://github.com/k2-fsa/k2/pull/917
target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
target_link_libraries(mutual_information_core PUBLIC cub)
endif()
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang, Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
......@@ -19,7 +19,7 @@
#ifndef FAST_RNNT_CSRC_DEVICE_GUARD_H_
#define FAST_RNNT_CSRC_DEVICE_GUARD_H_
#include <torch/script.h>
#include "torch/script.h"
// This file is modified from
// https://github.com/k2-fsa/k2/blob/master/k2/csrc/device_guard.h
......@@ -65,15 +65,23 @@ public:
private:
static int32_t GetDevice() {
#ifdef FT_WITH_CUDA
int32_t device;
auto s = cudaGetDevice(&device);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
return device;
#else
return -1;
#endif
}
static void SetDevice(int32_t device) {
#ifdef FT_WITH_CUDA
auto s = cudaSetDevice(device);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
return;
#endif
}
private:
......
......@@ -21,10 +21,9 @@
#ifndef FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#include <torch/extension.h>
#include <cmath>
#include <vector>
#include "torch/extension.h"
#ifdef __CUDA_ARCH__
#define FT_CUDA_HOSTDEV __host__ __device__
......
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/utils.h"
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src) {
TORCH_CHECK(src.dim() == 1, "Only one-dimensional tensors are supported");
TORCH_CHECK(src.scalar_type() == torch::kLong, "Only LongTensor is supported");
TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
int32_t dim = src.numel();
if (src.device().type() == torch::kCPU) {
int64_t min_value = std::numeric_limits<int64_t>::max();
int64_t *src_data = src.data_ptr<int64_t>();
for (int32_t i = dim - 1; i >= 0; --i) {
min_value = std::min(src_data[i], min_value);
src[i] = min_value;
}
} else {
#ifdef FT_WITH_CUDA
TORCH_CHECK(src.device().is_cuda());
internal::MinOp<int64_t> min_op;
auto src_data = src.data_ptr<int64_t>();
internal::ConstReversedPtr<int64_t> src_ptr =
internal::ConstReversedPtr<int64_t>(src_data, dim);
internal::ReversedPtr<int64_t> dest_ptr =
internal::ReversedPtr<int64_t>(src_data, dim);
// The first time is to determine temporary device storage requirements.
std::size_t temp_storage_bytes = 0;
auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
torch::dtype(torch::kInt8).device(src.device()));
int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
src_ptr, dest_ptr, min_op, dim);
TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif // FT_WITH_CUDA
}
}
} // namespace fast_rnnt
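MonotonicLowerBound replaces src[i] with min(src[i], ..., src[dim - 1]) in place: a suffix minimum, computed with a backward loop on the CPU and, on CUDA, as an inclusive min-scan over a reversed view of the data (the ConstReversedPtr/ReversedPtr iterators in utils.h below). A minimal PyTorch sketch of the same computation, for illustration only and not part of the commit, assuming a 1-D tensor:

import torch

def monotonic_lower_bound_reference(src: torch.Tensor) -> torch.Tensor:
    """Return out with out[i] = min(src[i], src[i+1], ..., src[-1])."""
    # Running-minimum scan over the flipped tensor, flipped back: this is
    # exactly the reversed inclusive min-scan used in the CUDA branch above.
    return torch.cummin(src.flip(0), dim=0).values.flip(0)

x = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6])
print(monotonic_lower_bound_reference(x))
# tensor([1, 1, 1, 1, 2, 2, 2, 6])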
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_
#include "torch/extension.h"
#ifdef FT_WITH_CUDA
#include "cub/cub.cuh" // NOLINT
namespace fast_rnnt {
namespace internal {
template <typename T> struct MinOp {
__host__ __device__ __forceinline__ T operator()(const T &a,
const T &b) const {
return (a > b) ? b : a;
}
};
// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
const T *data;
// data points to the last element now
explicit ConstReversedPtr(const T *data, int32_t size)
: data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
return data[-i];
}
__host__ __device__ __forceinline__ ConstReversedPtr
operator+(int32_t n) const {
ConstReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ const T &operator*() const {
return *data;
}
};
template <typename T> struct ReversedPtr {
T *data;
// data points to the last element now
explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}
// operator[], operator+, and operator* are required by
// cub::DeviceScan::InclusiveScan
__host__ __device__ __forceinline__ T &operator[](int32_t i) {
return data[-i];
}
__host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
ReversedPtr tmp(*this);
tmp.data -= n;
return tmp;
}
__host__ __device__ __forceinline__ T &operator*() { return *data; }
};
} // namespace internal
} // namespace fast_rnnt
namespace std {
// value_type is required by cub::DeviceScan::InclusiveScan
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
typedef T value_type;
};
} // namespace std
#endif // FT_WITH_CUDA
namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_UTILS_H_
include_directories(${CMAKE_SOURCE_DIR})
pybind11_add_module(_fast_rnnt
include(transform)
# please keep the list sorted
set(fast_rnnt_srcs
fast_rnnt.cu
mutual_information.cu
utils.cu
)
if(NOT FT_WITH_CUDA)
transform(OUTPUT_VARIABLE fast_rnnt_srcs SRCS ${fast_rnnt_srcs})
endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
if(UNIX AND NOT APPLE)
......
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/python/csrc/fast_rnnt.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/utils.h"
PYBIND11_MODULE(_fast_rnnt, m) {
m.doc() = "Python wrapper for Fast Rnnt.";
fast_rnnt::PybindMutualInformation(m);
fast_rnnt::PybindUtils(m);
}
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
#define FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
#include "pybind11/pybind11.h"
namespace py = pybind11;
#endif // FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_
......@@ -22,9 +22,8 @@
#include "fast_rnnt/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
PYBIND11_MODULE(_fast_rnnt, m) {
m.doc() = "Python wrapper for Mutual Information.";
namespace fast_rnnt {
void PybindMutualInformation(py::module &m) {
m.def(
"mutual_information_forward",
[](torch::Tensor px, torch::Tensor py,
......@@ -32,10 +31,10 @@ PYBIND11_MODULE(_fast_rnnt, m) {
torch::Tensor p) -> torch::Tensor {
fast_rnnt::DeviceGuard guard(px.device());
if (px.device().is_cpu()) {
return fast_rnnt::MutualInformationCpu(px, py, boundary, p);
return MutualInformationCpu(px, py, boundary, p);
} else {
#ifdef FT_WITH_CUDA
return fast_rnnt::MutualInformationCuda(px, py, boundary, p);
return MutualInformationCuda(px, py, boundary, p);
#else
TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
"that you compiled the code with K2_WITH_CUDA.");
......@@ -52,12 +51,11 @@ PYBIND11_MODULE(_fast_rnnt, m) {
torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
fast_rnnt::DeviceGuard guard(px.device());
if (px.device().is_cpu()) {
return fast_rnnt::MutualInformationBackwardCpu(px, py, boundary, p,
ans_grad);
return MutualInformationBackwardCpu(px, py, boundary, p, ans_grad);
} else {
#ifdef FT_WITH_CUDA
return fast_rnnt::MutualInformationBackwardCuda(px, py, boundary, p,
ans_grad, true);
return MutualInformationBackwardCuda(px, py, boundary, p, ans_grad,
true);
#else
TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
"that you compiled the code with K2_WITH_CUDA.");
......@@ -68,3 +66,4 @@ PYBIND11_MODULE(_fast_rnnt, m) {
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad"));
}
} // namespace fast_rnnt
......@@ -21,8 +21,12 @@
#ifndef FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
#include "pybind11/pybind11.h"
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace py = pybind11;
namespace fast_rnnt {
#endif // FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
void PybindMutualInformation(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"
namespace fast_rnnt {
void PybindUtils(py::module &m) {
m.def("monotonic_lower_bound_", [](torch::Tensor &src) -> void {
DeviceGuard guard(src.device());
if (src.dim() == 1) {
MonotonicLowerBound(src);
} else if (src.dim() == 2) {
int32_t dim0 = src.sizes()[0];
for (int32_t i = 0; i < dim0; ++i) {
auto sub = src.index({i, torch::indexing::Slice()});
MonotonicLowerBound(sub);
}
} else {
TORCH_CHECK(false, "Only support 1 dimension and 2 dimensions tensor");
}
}, py::arg("src"));
}
} // namespace fast_rnnt
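On the Python side, monotonic_lower_bound_ mutates its argument in place and accepts 1-D or 2-D tensors, applying the suffix minimum row by row in the 2-D case. A small usage sketch, assuming the extension has been built and installed so the import below succeeds (the binding requires a contiguous LongTensor):

import torch
import fast_rnnt

s = torch.tensor([[3, 1, 4],
                  [5, 9, 2]], dtype=torch.int64)
fast_rnnt.monotonic_lower_bound_(s)  # in place, row by row
print(s)
# tensor([[1, 1, 4],
#         [2, 2, 2]])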
/**
* @brief python wrappers for utils.h
*
* @copyright
* Copyright 2022 Xiaomi Corp. (author: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindUtils(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
from _fast_rnnt import monotonic_lower_bound_
from .mutual_information import mutual_information_recursion
from .mutual_information import joint_mutual_information_recursion
......
......@@ -16,7 +16,7 @@
import os
import k2
import fast_rnnt
import torch
from torch import Tensor
from typing import Optional, Tuple, Union
......@@ -463,13 +463,13 @@ def _adjust_pruning_lower_bound(
"""
# s_begin (B, T)
(B, T) = s_begin.shape
s_begin = k2.monotonic_lower_bound(s_begin)
fast_rnnt.monotonic_lower_bound_(s_begin)
# do the magic transformation
s_begin = -(
s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
)
# make the transformed tensor to be non-decreasing
s_begin = k2.monotonic_lower_bound(s_begin)
fast_rnnt.monotonic_lower_bound_(s_begin)
# make start symbol to be zero.
s_begin = torch.where(s_begin < 0, 0, s_begin)
# do the magic transformation again to recover s_begin
......
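The second lower-bound pass operates on the transformed sequence t[i] = (s_range - 1) * i - s_begin[i]; making t non-decreasing is equivalent to enforcing s_begin[i + 1] <= s_begin[i] + (s_range - 1), and clamping t at zero forces s_begin[i] <= (s_range - 1) * i, so every frame is reachable from symbol 0 at frame 0. A numeric sketch of these two steps (illustration only; the final "recover" step below assumes the same transformation is applied once more, as the truncated comment above indicates, and s_range = 3 is a hypothetical value):

import torch

s_range = 3
s_begin = torch.tensor([0, 4, 4, 5], dtype=torch.int64)  # already non-decreasing
T = s_begin.numel()

def suffix_min(x: torch.Tensor) -> torch.Tensor:
    # Out-of-place equivalent of monotonic_lower_bound_.
    return torch.cummin(x.flip(0), dim=0).values.flip(0)

t = (s_range - 1) * torch.arange(T) - s_begin    # the "magic transformation"
t = suffix_min(t).clamp(min=0)                   # non-decreasing, start at zero
s_adjusted = (s_range - 1) * torch.arange(T) - t  # transform back
print(s_adjusted)
# tensor([0, 2, 4, 5]): non-decreasing, and each step grows by at most s_range - 1.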