Commit 0a9698b7 authored by pkufool's avatar pkufool
Browse files

Minor fixes

parent b0b548c9
......@@ -18,7 +18,6 @@ function(download_googltest)
# FetchContent is available since 3.11,
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
# so that it can be used in lower CMake versions.
message(STATUS "Use FetchContent provided by k2")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
......
......@@ -25,20 +25,32 @@
namespace fast_rnnt {
void PybindUtils(py::module &m) {
m.def("monotonic_lower_bound_", [](torch::Tensor &src) -> void {
m.def(
"monotonic_lower_bound_",
[](torch::Tensor &src) -> void {
DeviceGuard guard(src.device());
if (src.dim() == 1) {
MonotonicLowerBound(src);
} else if (src.dim() == 2) {
int32_t dim0 = src.sizes()[0];
for (int32_t i = 0; i < dim0; ++i) {
auto sub = src.index({i, torch::indexing::Slice()});
auto sub = src.index({i});
MonotonicLowerBound(sub);
}
} else {
TORCH_CHECK(false, "Only support 1 dimension and 2 dimensions tensor");
TORCH_CHECK(false,
"Only support 1 dimension and 2 dimensions tensor");
}
}, py::arg("src"));
},
py::arg("src"));
m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
return true;
#else
return false;
#endif
});
}
} // namespace fast_rnnt
from _fast_rnnt import monotonic_lower_bound_
from _fast_rnnt import with_cuda
from .mutual_information import mutual_information_recursion
from .mutual_information import joint_mutual_information_recursion
......
......@@ -34,7 +34,7 @@ class TestMutualInformation(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device("cpu")]
if torch.cuda.is_available():
if torch.cuda.is_available() and fast_rnnt.with_cuda():
cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
......
......@@ -32,7 +32,7 @@ class TestRnntLoss(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device("cpu")]
if torch.cuda.is_available():
if torch.cuda.is_available() and fast_rnnt.with_cuda():
cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment