Commit 0a9698b7 authored by pkufool
Browse files

Minor fixes

parent b0b548c9
...@@ -18,7 +18,6 @@ function(download_googltest) ...@@ -18,7 +18,6 @@ function(download_googltest)
# FetchContent is available since 3.11, # FetchContent is available since 3.11,
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
# so that it can be used in lower CMake versions. # so that it can be used in lower CMake versions.
message(STATUS "Use FetchContent provided by k2")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif() endif()
......
...@@ -25,20 +25,32 @@ ...@@ -25,20 +25,32 @@
namespace fast_rnnt {

// Registers the fast_rnnt utility functions on the given pybind11 module.
void PybindUtils(py::module &m) {
  // In-place monotonic lower bound. A 1-D tensor is transformed directly;
  // a 2-D tensor is transformed row by row. Any other rank is rejected.
  m.def(
      "monotonic_lower_bound_",
      [](torch::Tensor &src) -> void {
        DeviceGuard guard(src.device());  // run on the tensor's own device
        if (src.dim() == 1) {
          MonotonicLowerBound(src);
        } else if (src.dim() == 2) {
          int32_t num_rows = src.sizes()[0];
          for (int32_t row = 0; row < num_rows; ++row) {
            // src.index({row}) returns a view of row `row`, so the
            // in-place update below writes through to `src`.
            auto sub = src.index({row});
            MonotonicLowerBound(sub);
          }
        } else {
          TORCH_CHECK(false,
                      "Only support 1 dimension and 2 dimensions tensor");
        }
      },
      py::arg("src"));

  // Returns true iff this extension was compiled with CUDA support
  // (i.e. FT_WITH_CUDA was defined at build time).
  m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
    return true;
#else
    return false;
#endif
  });
}

}  // namespace fast_rnnt
from _fast_rnnt import monotonic_lower_bound_ from _fast_rnnt import monotonic_lower_bound_
from _fast_rnnt import with_cuda
from .mutual_information import mutual_information_recursion from .mutual_information import mutual_information_recursion
from .mutual_information import joint_mutual_information_recursion from .mutual_information import joint_mutual_information_recursion
......
...@@ -34,7 +34,7 @@ class TestMutualInformation(unittest.TestCase): ...@@ -34,7 +34,7 @@ class TestMutualInformation(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.devices = [torch.device("cpu")] cls.devices = [torch.device("cpu")]
if torch.cuda.is_available(): if torch.cuda.is_available() and fast_rnnt.with_cuda():
cls.devices.append(torch.device("cuda", 0)) cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
torch.cuda.set_device(1) torch.cuda.set_device(1)
......
...@@ -32,7 +32,7 @@ class TestRnntLoss(unittest.TestCase): ...@@ -32,7 +32,7 @@ class TestRnntLoss(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.devices = [torch.device("cpu")] cls.devices = [torch.device("cpu")]
if torch.cuda.is_available(): if torch.cuda.is_available() and fast_rnnt.with_cuda():
cls.devices.append(torch.device("cuda", 0)) cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
torch.cuda.set_device(1) torch.cuda.set_device(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment