Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
0a9698b7
Commit
0a9698b7
authored
Mar 08, 2022
by
pkufool
Browse files
Minor fixes
parent
b0b548c9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
17 deletions
+29
-17
cmake/googletest.cmake
cmake/googletest.cmake
+0
-1
fast_rnnt/python/csrc/utils.cu
fast_rnnt/python/csrc/utils.cu
+26
-14
fast_rnnt/python/fast_rnnt/__init__.py
fast_rnnt/python/fast_rnnt/__init__.py
+1
-0
fast_rnnt/python/tests/mutual_information_test.py
fast_rnnt/python/tests/mutual_information_test.py
+1
-1
fast_rnnt/python/tests/rnnt_loss_test.py
fast_rnnt/python/tests/rnnt_loss_test.py
+1
-1
No files found.
cmake/googletest.cmake
View file @
0a9698b7
...
@@ -18,7 +18,6 @@ function(download_googltest)
...
@@ -18,7 +18,6 @@ function(download_googltest)
# FetchContent is available since 3.11,
# FetchContent is available since 3.11,
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
# so that it can be used in lower CMake versions.
# so that it can be used in lower CMake versions.
message
(
STATUS
"Use FetchContent provided by k2"
)
list
(
APPEND CMAKE_MODULE_PATH
${
CMAKE_SOURCE_DIR
}
/cmake/Modules
)
list
(
APPEND CMAKE_MODULE_PATH
${
CMAKE_SOURCE_DIR
}
/cmake/Modules
)
endif
()
endif
()
...
...
fast_rnnt/python/csrc/utils.cu
View file @
0a9698b7
...
@@ -25,20 +25,32 @@
...
@@ -25,20 +25,32 @@
namespace
fast_rnnt
{
namespace
fast_rnnt
{
void
PybindUtils
(
py
::
module
&
m
)
{
void
PybindUtils
(
py
::
module
&
m
)
{
m
.
def
(
"monotonic_lower_bound_"
,
[](
torch
::
Tensor
&
src
)
->
void
{
m
.
def
(
DeviceGuard
guard
(
src
.
device
());
"monotonic_lower_bound_"
,
if
(
src
.
dim
()
==
1
)
{
[](
torch
::
Tensor
&
src
)
->
void
{
MonotonicLowerBound
(
src
);
DeviceGuard
guard
(
src
.
device
());
}
else
if
(
src
.
dim
()
==
2
)
{
if
(
src
.
dim
()
==
1
)
{
int32_t
dim0
=
src
.
sizes
()[
0
];
MonotonicLowerBound
(
src
);
for
(
int32_t
i
=
0
;
i
<
dim0
;
++
i
)
{
}
else
if
(
src
.
dim
()
==
2
)
{
auto
sub
=
src
.
index
({
i
,
torch
::
indexing
::
Slice
()});
int32_t
dim0
=
src
.
sizes
()[
0
];
MonotonicLowerBound
(
sub
);
for
(
int32_t
i
=
0
;
i
<
dim0
;
++
i
)
{
}
auto
sub
=
src
.
index
({
i
});
}
else
{
MonotonicLowerBound
(
sub
);
TORCH_CHECK
(
false
,
"Only support 1 dimension and 2 dimensions tensor"
);
}
}
}
else
{
},
py
::
arg
(
"src"
));
TORCH_CHECK
(
false
,
"Only support 1 dimension and 2 dimensions tensor"
);
}
},
py
::
arg
(
"src"
));
m
.
def
(
"with_cuda"
,
[]()
->
bool
{
#ifdef FT_WITH_CUDA
return
true
;
#else
return
false
;
#endif
});
}
}
}
// namespace fast_rnnt
}
// namespace fast_rnnt
fast_rnnt/python/fast_rnnt/__init__.py
View file @
0a9698b7
from
_fast_rnnt
import
monotonic_lower_bound_
from
_fast_rnnt
import
monotonic_lower_bound_
from
_fast_rnnt
import
with_cuda
from
.mutual_information
import
mutual_information_recursion
from
.mutual_information
import
mutual_information_recursion
from
.mutual_information
import
joint_mutual_information_recursion
from
.mutual_information
import
joint_mutual_information_recursion
...
...
fast_rnnt/python/tests/mutual_information_test.py
View file @
0a9698b7
...
@@ -34,7 +34,7 @@ class TestMutualInformation(unittest.TestCase):
...
@@ -34,7 +34,7 @@ class TestMutualInformation(unittest.TestCase):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
cls
.
devices
=
[
torch
.
device
(
"cpu"
)]
cls
.
devices
=
[
torch
.
device
(
"cpu"
)]
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
()
and
fast_rnnt
.
with_cuda
()
:
cls
.
devices
.
append
(
torch
.
device
(
"cuda"
,
0
))
cls
.
devices
.
append
(
torch
.
device
(
"cuda"
,
0
))
if
torch
.
cuda
.
device_count
()
>
1
:
if
torch
.
cuda
.
device_count
()
>
1
:
torch
.
cuda
.
set_device
(
1
)
torch
.
cuda
.
set_device
(
1
)
...
...
fast_rnnt/python/tests/rnnt_loss_test.py
View file @
0a9698b7
...
@@ -32,7 +32,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -32,7 +32,7 @@ class TestRnntLoss(unittest.TestCase):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
cls
.
devices
=
[
torch
.
device
(
"cpu"
)]
cls
.
devices
=
[
torch
.
device
(
"cpu"
)]
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
()
and
fast_rnnt
.
with_cuda
()
:
cls
.
devices
.
append
(
torch
.
device
(
"cuda"
,
0
))
cls
.
devices
.
append
(
torch
.
device
(
"cuda"
,
0
))
if
torch
.
cuda
.
device_count
()
>
1
:
if
torch
.
cuda
.
device_count
()
>
1
:
torch
.
cuda
.
set_device
(
1
)
torch
.
cuda
.
set_device
(
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment