Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
2c2dc4b9
Unverified
Commit
2c2dc4b9
authored
Oct 31, 2022
by
Daniel Povey
Committed by
GitHub
Oct 31, 2022
Browse files
Merge pull request #16 from pkufool/contiguous
Update rnnt_loss
parents
c268c3d5
03b82cc4
Changes
15
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
434 additions
and
478 deletions
+434
-478
CMakeLists.txt
CMakeLists.txt
+1
-5
cmake/cub.cmake
cmake/cub.cmake
+0
-42
fast_rnnt/csrc/CMakeLists.txt
fast_rnnt/csrc/CMakeLists.txt
+0
-8
fast_rnnt/csrc/utils.cu
fast_rnnt/csrc/utils.cu
+0
-63
fast_rnnt/csrc/utils.h
fast_rnnt/csrc/utils.h
+0
-101
fast_rnnt/python/csrc/CMakeLists.txt
fast_rnnt/python/csrc/CMakeLists.txt
+0
-2
fast_rnnt/python/csrc/fast_rnnt.cu
fast_rnnt/python/csrc/fast_rnnt.cu
+0
-2
fast_rnnt/python/csrc/mutual_information.cu
fast_rnnt/python/csrc/mutual_information.cu
+8
-0
fast_rnnt/python/csrc/utils.cu
fast_rnnt/python/csrc/utils.cu
+0
-56
fast_rnnt/python/csrc/utils.h
fast_rnnt/python/csrc/utils.h
+0
-32
fast_rnnt/python/fast_rnnt/__init__.py
fast_rnnt/python/fast_rnnt/__init__.py
+0
-1
fast_rnnt/python/fast_rnnt/mutual_information.py
fast_rnnt/python/fast_rnnt/mutual_information.py
+6
-4
fast_rnnt/python/fast_rnnt/rnnt_loss.py
fast_rnnt/python/fast_rnnt/rnnt_loss.py
+379
-129
fast_rnnt/python/tests/rnnt_loss_test.py
fast_rnnt/python/tests/rnnt_loss_test.py
+38
-30
setup.py
setup.py
+2
-3
No files found.
CMakeLists.txt
View file @
2c2dc4b9
...
@@ -46,7 +46,7 @@ message(STATUS "Enabled languages: ${languages}")
...
@@ -46,7 +46,7 @@ message(STATUS "Enabled languages: ${languages}")
project
(
fast_rnnt
${
languages
}
)
project
(
fast_rnnt
${
languages
}
)
set
(
FT_VERSION
"1.
0
"
)
set
(
FT_VERSION
"1.
2
"
)
set
(
ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel
)
set
(
ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel
)
set
(
DEFAULT_BUILD_TYPE
"Release"
)
set
(
DEFAULT_BUILD_TYPE
"Release"
)
...
@@ -133,10 +133,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
...
@@ -133,10 +133,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include
(
pybind11
)
include
(
pybind11
)
include
(
torch
)
include
(
torch
)
if
(
FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0
)
# CUB is included in CUDA toolkit 11.0 and above
include
(
cub
)
endif
()
if
(
FT_BUILD_TESTS
)
if
(
FT_BUILD_TESTS
)
enable_testing
()
enable_testing
()
...
...
cmake/cub.cmake
deleted
100644 → 0
View file @
c268c3d5
# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# download_cub()
#
# Fetches the CUB 1.15.0 release archive through FetchContent and exposes it
# as an INTERFACE library target named `cub` whose include directory is the
# unpacked source tree.
#
# Takes no arguments. Calling it more than once is harmless: the archive is
# only populated when FetchContent has not populated it yet, and the `cub`
# target is only created when it does not already exist.
function(download_cub)
  if(CMAKE_VERSION VERSION_LESS 3.11)
    # NOTE(review): presumably a bundled FetchContent module lives under
    # cmake/Modules for CMake versions that do not ship one - confirm.
    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
  endif()

  include(FetchContent)

  set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
  set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")

  FetchContent_Declare(cub
    URL ${cub_URL}
    URL_HASH ${cub_HASH}
  )

  # FetchContent_GetProperties() sets cub_POPULATED / cub_SOURCE_DIR; testing
  # the bare name `cub` (never defined as a variable) would always be true.
  FetchContent_GetProperties(cub)
  if(NOT cub_POPULATED)
    message(STATUS "Downloading cub")
    FetchContent_Populate(cub)
  endif()
  message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}")

  # Guard against redefining the target if this function runs again.
  if(NOT TARGET cub)
    add_library(cub INTERFACE)
    target_include_directories(cub INTERFACE ${cub_SOURCE_DIR})
  endif()
endfunction()

download_cub()
fast_rnnt/csrc/CMakeLists.txt
View file @
2c2dc4b9
...
@@ -5,7 +5,6 @@ include(transform)
...
@@ -5,7 +5,6 @@ include(transform)
set
(
srcs
set
(
srcs
mutual_information_cpu.cu
mutual_information_cpu.cu
utils.cu
)
)
if
(
NOT FT_WITH_CUDA
)
if
(
NOT FT_WITH_CUDA
)
...
@@ -18,10 +17,3 @@ add_library(mutual_information_core ${srcs})
...
@@ -18,10 +17,3 @@ add_library(mutual_information_core ${srcs})
target_link_libraries
(
mutual_information_core PUBLIC
${
TORCH_LIBRARIES
}
)
target_link_libraries
(
mutual_information_core PUBLIC
${
TORCH_LIBRARIES
}
)
# for <torch/extension.h>
# for <torch/extension.h>
target_include_directories
(
mutual_information_core PUBLIC
${
PYTHON_INCLUDE_DIRS
}
)
target_include_directories
(
mutual_information_core PUBLIC
${
PYTHON_INCLUDE_DIRS
}
)
# see https://github.com/NVIDIA/thrust/issues/1401
# and https://github.com/k2-fsa/k2/pull/917
target_compile_definitions
(
mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt
)
target_compile_definitions
(
mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust
)
if
(
FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0
)
target_link_libraries
(
mutual_information_core PUBLIC cub
)
endif
()
fast_rnnt/csrc/utils.cu
deleted
100644 → 0
View file @
c268c3d5
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/utils.h"
namespace fast_rnnt {

/**
 * Replaces, in place, every element of `src` with the minimum of itself and
 * all elements that follow it (a reverse inclusive min-scan). After the call
 * the tensor is non-decreasing from left to right.
 *
 * @param src  One-dimensional, contiguous tensor of dtype torch::kLong,
 *             located on the CPU or (when built with FT_WITH_CUDA) on a CUDA
 *             device. Modified in place.
 *
 * Raises (via TORCH_CHECK) if `src` is not 1-D, not a LongTensor, not
 * contiguous, has more elements than fit in int32_t, or lives on a non-CPU
 * device while the library was built without FT_WITH_CUDA.
 */
void MonotonicLowerBound(torch::Tensor &src) {
  TORCH_CHECK(src.dim() == 1, "Only support one dimension tensor");
  TORCH_CHECK(src.scalar_type() == torch::kLong, "Only support LongTensor");
  TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
  // The reversed-pointer helpers and the scan below take a 32-bit length;
  // refuse larger tensors instead of silently truncating numel().
  TORCH_CHECK(src.numel() <= std::numeric_limits<int32_t>::max(),
              "Tensor has too many elements");

  int32_t dim = static_cast<int32_t>(src.numel());
  if (src.device().type() == torch::kCPU) {
    int64_t min_value = std::numeric_limits<int64_t>::max();
    int64_t *src_data = src.data_ptr<int64_t>();
    // Walk from the last element to the first, carrying the running minimum.
    for (int32_t i = dim - 1; i >= 0; --i) {
      min_value = std::min(src_data[i], min_value);
      // Write through the raw pointer; indexing the tensor (`src[i] = ...`)
      // would go through a tensor operation for every single element.
      src_data[i] = min_value;
    }
  } else {
#ifdef FT_WITH_CUDA
    TORCH_CHECK(src.device().is_cuda());
    if (dim == 0) {
      // Nothing to scan; also avoids forming `data - 1` in the reversed
      // pointer wrappers for an empty tensor.
      return;
    }
    internal::MinOp<int64_t> min_op;
    auto src_data = src.data_ptr<int64_t>();
    // Input and output iterate over the same buffer back to front.
    internal::ConstReversedPtr<int64_t> src_ptr =
        internal::ConstReversedPtr<int64_t>(src_data, dim);
    internal::ReversedPtr<int64_t> dest_ptr =
        internal::ReversedPtr<int64_t>(src_data, dim);
    // The first time is to determine temporary device storage requirements.
    std::size_t temp_storage_bytes = 0;
    auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
                                            src_ptr, dest_ptr, min_op, dim);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
    // Scratch space is allocated as an int8 tensor on the same device as src.
    auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
                               torch::dtype(torch::kInt8).device(src.device()));
    int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
    // The second call performs the actual scan.
    s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
                                       src_ptr, dest_ptr, min_op, dim);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
    TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif  // FT_WITH_CUDA
  }
}

}  // namespace fast_rnnt
fast_rnnt/csrc/utils.h
deleted
100644 → 0
View file @
c268c3d5
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_
#include "torch/extension.h"
#ifdef FT_WITH_CUDA
#include "cub/cub.cuh" // NOLINT
namespace fast_rnnt {
namespace internal {

// Binary functor that yields the smaller of its two arguments; when neither
// compares greater, the first argument is returned.
template <typename T>
struct MinOp {
  __host__ __device__ __forceinline__ T operator()(const T &a,
                                                   const T &b) const {
    if (a > b) {
      return b;
    }
    return a;
  }
};

// Read-only pointer wrapper that walks a buffer back to front. It is passed
// as the input iterator to cub::DeviceScan::InclusiveScan by
// MonotonicLowerBound.
template <typename T>
struct ConstReversedPtr {
  const T *data;

  // After construction `data` addresses the last of the `size` elements.
  explicit ConstReversedPtr(const T *data, int32_t size)
      : data(data + size - 1) {}

  // cub::DeviceScan::InclusiveScan needs operator[], operator+ and operator*.

  // Element `i` counted from the end of the buffer.
  __host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
    return *(data - i);
  }

  // Advancing by `n` moves the underlying pointer `n` elements backwards.
  __host__ __device__ __forceinline__ ConstReversedPtr
  operator+(int32_t n) const {
    ConstReversedPtr shifted = *this;
    shifted.data -= n;
    return shifted;
  }

  __host__ __device__ __forceinline__ const T &operator*() const {
    return data[0];
  }
};

// Writable counterpart of ConstReversedPtr; passed as the output iterator to
// cub::DeviceScan::InclusiveScan by MonotonicLowerBound.
template <typename T>
struct ReversedPtr {
  T *data;

  // After construction `data` addresses the last of the `size` elements.
  explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}

  // cub::DeviceScan::InclusiveScan needs operator[], operator+ and operator*.

  // Element `i` counted from the end of the buffer.
  __host__ __device__ __forceinline__ T &operator[](int32_t i) {
    return *(data - i);
  }

  // Advancing by `n` moves the underlying pointer `n` elements backwards.
  __host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
    ReversedPtr shifted = *this;
    shifted.data -= n;
    return shifted;
  }

  __host__ __device__ __forceinline__ T &operator*() { return data[0]; }
};

}  // namespace internal
}  // namespace fast_rnnt
namespace std {

// cub::DeviceScan::InclusiveScan looks up value_type through
// std::iterator_traits, so provide it for both reversed-pointer wrappers.
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
  using value_type = T;
};

template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
  using value_type = T;
};

}  // namespace std
#endif // FT_WITH_CUDA
namespace
fast_rnnt
{
// Modifies `src` in place so that each element becomes the minimum of itself
// and every element after it. The implementation checks that `src` is a
// one-dimensional, contiguous tensor of dtype torch::kLong.
void
MonotonicLowerBound
(
torch
::
Tensor
&
src
);
}
// namespace fast_rnnt
#endif // FAST_RNNT_CSRC_UTILS_H_
fast_rnnt/python/csrc/CMakeLists.txt
View file @
2c2dc4b9
...
@@ -6,7 +6,6 @@ include(transform)
...
@@ -6,7 +6,6 @@ include(transform)
set
(
fast_rnnt_srcs
set
(
fast_rnnt_srcs
fast_rnnt.cu
fast_rnnt.cu
mutual_information.cu
mutual_information.cu
utils.cu
)
)
if
(
NOT FT_WITH_CUDA
)
if
(
NOT FT_WITH_CUDA
)
...
@@ -14,7 +13,6 @@ if(NOT FT_WITH_CUDA)
...
@@ -14,7 +13,6 @@ if(NOT FT_WITH_CUDA)
endif
()
endif
()
pybind11_add_module
(
_fast_rnnt
${
fast_rnnt_srcs
}
)
pybind11_add_module
(
_fast_rnnt
${
fast_rnnt_srcs
}
)
target_link_libraries
(
_fast_rnnt PRIVATE mutual_information_core
)
target_link_libraries
(
_fast_rnnt PRIVATE mutual_information_core
)
...
...
fast_rnnt/python/csrc/fast_rnnt.cu
View file @
2c2dc4b9
...
@@ -20,11 +20,9 @@
...
@@ -20,11 +20,9 @@
#include "fast_rnnt/python/csrc/fast_rnnt.h"
#include "fast_rnnt/python/csrc/fast_rnnt.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/utils.h"
PYBIND11_MODULE
(
_fast_rnnt
,
m
)
{
PYBIND11_MODULE
(
_fast_rnnt
,
m
)
{
m
.
doc
()
=
"Python wrapper for Fast Rnnt."
;
m
.
doc
()
=
"Python wrapper for Fast Rnnt."
;
fast_rnnt
::
PybindMutualInformation
(
m
);
fast_rnnt
::
PybindMutualInformation
(
m
);
fast_rnnt
::
PybindUtils
(
m
);
}
}
fast_rnnt/python/csrc/mutual_information.cu
View file @
2c2dc4b9
...
@@ -65,5 +65,13 @@ void PybindMutualInformation(py::module &m) {
...
@@ -65,5 +65,13 @@ void PybindMutualInformation(py::module &m) {
},
},
py
::
arg
(
"px"
),
py
::
arg
(
"py"
),
py
::
arg
(
"boundary"
),
py
::
arg
(
"p"
),
py
::
arg
(
"px"
),
py
::
arg
(
"py"
),
py
::
arg
(
"boundary"
),
py
::
arg
(
"p"
),
py
::
arg
(
"ans_grad"
));
py
::
arg
(
"ans_grad"
));
m
.
def
(
"with_cuda"
,
[]()
->
bool
{
#ifdef FT_WITH_CUDA
return
true
;
#else
return
false
;
#endif
});
}
}
}
// namespace fast_rnnt
}
// namespace fast_rnnt
fast_rnnt/python/csrc/utils.cu
deleted
100644 → 0
View file @
c268c3d5
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"
namespace fast_rnnt {

// Adds two functions to the Python module `m`:
//   - monotonic_lower_bound_(src): runs MonotonicLowerBound in place on a
//     1-D tensor, or on each row of a 2-D tensor; other ranks are rejected.
//   - with_cuda(): reports whether this extension was compiled with
//     FT_WITH_CUDA defined.
void PybindUtils(py::module &m) {
  m.def(
      "monotonic_lower_bound_",
      [](torch::Tensor &src) -> void {
        DeviceGuard guard(src.device());
        const auto ndim = src.dim();

        if (ndim == 1) {
          MonotonicLowerBound(src);
          return;
        }

        if (ndim != 2) {
          TORCH_CHECK(false,
                      "Only support 1 dimension and 2 dimensions tensor");
        }

        // 2-D input: process the rows one after another.
        int32_t num_rows = src.sizes()[0];
        for (int32_t row = 0; row < num_rows; ++row) {
          auto row_tensor = src.index({row});
          MonotonicLowerBound(row_tensor);
        }
      },
      py::arg("src"));

  m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
    return true;
#else
    return false;
#endif
  });
}

}  // namespace fast_rnnt
fast_rnnt/python/csrc/utils.h
deleted
100644 → 0
View file @
c268c3d5
/**
* @brief python wrappers for utils.h
*
* @copyright
* Copyright 2022 Xiaomi Corp. (author: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace
fast_rnnt
{
// Registers the utility bindings on the Python module `m`; the definition in
// fast_rnnt/python/csrc/utils.cu adds `monotonic_lower_bound_` and
// `with_cuda`.
void
PybindUtils
(
py
::
module
&
m
);
}
// namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
fast_rnnt/python/fast_rnnt/__init__.py
View file @
2c2dc4b9
from
_fast_rnnt
import
monotonic_lower_bound_
from
_fast_rnnt
import
with_cuda
from
_fast_rnnt
import
with_cuda
from
.mutual_information
import
mutual_information_recursion
from
.mutual_information
import
mutual_information_recursion
...
...
fast_rnnt/python/fast_rnnt/mutual_information.py
View file @
2c2dc4b9
...
@@ -285,9 +285,10 @@ def mutual_information_recursion(
...
@@ -285,9 +285,10 @@ def mutual_information_recursion(
for
s_begin
,
t_begin
,
s_end
,
t_end
in
boundary
.
tolist
():
for
s_begin
,
t_begin
,
s_end
,
t_end
in
boundary
.
tolist
():
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
t_begin
<=
t_end
<=
T
assert
0
<=
t_begin
<=
t_end
<=
T
# The following assertions are for efficiency
assert
px
.
is_contiguous
()
# The following statements are for efficiency
assert
py
.
is_contiguous
()
px
,
py
=
px
.
contiguous
(),
py
.
contiguous
()
pxy_grads
=
[
None
,
None
]
pxy_grads
=
[
None
,
None
]
scores
=
MutualInformationRecursionFunction
.
apply
(
px
,
py
,
pxy_grads
,
scores
=
MutualInformationRecursionFunction
.
apply
(
px
,
py
,
pxy_grads
,
boundary
,
return_grad
)
boundary
,
return_grad
)
...
@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(
...
@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
s_begin
<=
s_end
<=
S
assert
0
<=
t_begin
<=
t_end
<=
T
assert
0
<=
t_begin
<=
t_end
<=
T
# The following statements are for efficiency
px_tot
,
py_tot
=
px_tot
.
contiguous
(),
py_tot
.
contiguous
()
px_tot
,
py_tot
=
px_tot
.
contiguous
(),
py_tot
.
contiguous
()
# The following assertions are for efficiency
assert
px_tot
.
ndim
==
3
assert
px_tot
.
ndim
==
3
assert
py_tot
.
ndim
==
3
assert
py_tot
.
ndim
==
3
...
...
fast_rnnt/python/fast_rnnt/rnnt_loss.py
View file @
2c2dc4b9
This diff is collapsed.
Click to expand it.
fast_rnnt/python/tests/rnnt_loss_test.py
View file @
2c2dc4b9
...
@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
assert
px
.
shape
==
(
B
,
S
,
T
+
1
)
assert
px
.
shape
==
(
B
,
S
,
T
+
1
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
)
m
=
fast_rnnt
.
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
None
)
m
=
fast_rnnt
.
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
None
)
if
device
==
torch
.
device
(
"cpu"
):
if
device
==
torch
.
device
(
"cpu"
):
expected
=
-
m
expected
=
-
m
...
@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
3
]
=
frames
boundary_
[:,
3
]
=
frames
for
modified
in
[
True
,
False
]:
for
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
]:
for
device
in
self
.
devices
:
for
device
in
self
.
devices
:
# lm: [B][S+1][C]
# lm: [B][S+1][C]
lm
=
lm_
.
to
(
device
)
lm
=
lm_
.
to
(
device
)
...
@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
assert
(
px
.
shape
==
(
B
,
S
,
T
)
if
rnnt_type
!=
"regular"
else
(
B
,
S
,
T
+
1
)
)
)
assert
px
.
shape
==
(
B
,
S
,
T
)
if
modified
else
(
B
,
S
,
T
+
1
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
py
.
shape
==
(
B
,
S
+
1
,
T
)
assert
symbols
.
shape
==
(
B
,
S
)
assert
symbols
.
shape
==
(
B
,
S
)
m
=
fast_rnnt
.
mutual_information_recursion
(
m
=
fast_rnnt
.
mutual_information_recursion
(
...
@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale
=
0.0
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.0
,
am_only_scale
=
0.0
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
# compare with torchaudio rnnt_loss
# compare with torchaudio rnnt_loss
if
self
.
has_torch_rnnt_loss
and
not
modified
:
if
self
.
has_torch_rnnt_loss
and
rnnt_type
==
"regular"
:
import
torchaudio.functional
import
torchaudio.functional
m
=
torchaudio
.
functional
.
rnnt_loss
(
m
=
torchaudio
.
functional
.
rnnt_loss
(
...
@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale
=
0.0
,
lm_only_scale
=
0.0
,
am_only_scale
=
0.0
,
am_only_scale
=
0.0
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
assert
torch
.
allclose
(
m
,
expected
.
to
(
device
))
...
@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
torch_grad
=
torch
.
autograd
.
grad
(
torch_loss
,
logits2
)
torch_grad
=
torch
.
autograd
.
grad
(
torch_loss
,
logits2
)
torch_grad
=
torch_grad
[
0
]
torch_grad
=
torch_grad
[
0
]
assert
torch
.
allclose
(
fast_loss
,
torch_loss
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
fast_loss
,
torch_loss
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
fast_grad
,
torch_grad
,
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
fast_grad
,
torch_grad
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
test_rnnt_loss_smoothed
(
self
):
def
test_rnnt_loss_smoothed
(
self
):
B
=
1
B
=
1
...
@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
2
]
=
seq_length
boundary_
[:,
3
]
=
frames
boundary_
[:,
3
]
=
frames
for
modified
in
[
True
,
False
]:
for
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
]:
for
device
in
self
.
devices
:
for
device
in
self
.
devices
:
# normal rnnt
# normal rnnt
am
=
am_
.
to
(
device
)
am
=
am_
.
to
(
device
)
...
@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
print
(
print
(
f
"Unpruned rnnt loss with
{
rnnt_type
}
rnnt :
{
fast_loss
}
"
)
f
"Unpruned rnnt loss with modified
{
modified
}
:
{
fast_loss
}
"
)
# pruning
# pruning
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
...
@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
return_grad
=
True
,
return_grad
=
True
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
...
@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
s_range
=
r
,
s_range
=
r
,
)
)
# (B, T, r, C)
# (B, T, r, C)
pruned_am
,
pruned_lm
=
fast_rnnt
.
do_rnnt_pruning
(
am
=
am
,
lm
=
lm
,
ranges
=
ranges
)
pruned_am
,
pruned_lm
=
fast_rnnt
.
do_rnnt_pruning
(
am
=
am
,
lm
=
lm
,
ranges
=
ranges
)
logits
=
pruned_am
+
pruned_lm
logits
=
pruned_am
+
pruned_lm
...
@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
ranges
=
ranges
,
ranges
=
ranges
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
print
(
f
"Pruning loss with range
{
r
}
:
{
pruned_loss
}
"
)
print
(
f
"Pruning loss with range
{
r
}
:
{
pruned_loss
}
"
)
# Test the sequences that only have small number of symbols,
# Test the sequences that only have small number of symbols,
# at this circumstance, the s_range would be greater than S, which will
# at this circumstance, the s_range would be greater than S, which will
# raise errors (like, nan or inf loss) in our previous versions.
# raise errors (like, nan or inf loss) in our previous versions.
...
@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
print
(
f
"B =
{
B
}
, T =
{
T
}
, S =
{
S
}
, C =
{
C
}
"
)
print
(
f
"B =
{
B
}
, T =
{
T
}
, S =
{
S
}
, C =
{
C
}
"
)
for
modified
in
[
True
,
False
]:
for
rnnt_type
in
[
"regular"
,
"modified"
,
"constrained"
]:
for
device
in
self
.
devices
:
for
device
in
self
.
devices
:
# normal rnnt
# normal rnnt
am
=
am_
.
to
(
device
)
am
=
am_
.
to
(
device
)
...
@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
print
(
print
(
f
"Unpruned rnnt loss with
{
rnnt_type
}
rnnt :
{
loss
}
"
)
f
"Unpruned rnnt loss with modified
{
modified
}
:
{
loss
}
"
)
# pruning
# pruning
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
simple_loss
,
(
px_grad
,
py_grad
)
=
fast_rnnt
.
rnnt_loss_simple
(
...
@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
symbols
=
symbols
,
symbols
=
symbols
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
return_grad
=
True
,
return_grad
=
True
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
S0
=
2
S0
=
2
if
modified
:
if
rnnt_type
!=
"regular"
:
S0
=
1
S0
=
1
for
r
in
range
(
S0
,
S
+
2
):
for
r
in
range
(
S0
,
S
+
2
):
...
@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
...
@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
ranges
=
ranges
,
ranges
=
ranges
,
termination_symbol
=
terminal_symbol
,
termination_symbol
=
terminal_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
reduction
=
"none"
,
reduction
=
"none"
,
)
)
print
(
f
"Pruned loss with range
{
r
}
:
{
pruned_loss
}
"
)
print
(
f
"Pruned loss with range
{
r
}
:
{
pruned_loss
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
setup.py
View file @
2c2dc4b9
...
@@ -36,11 +36,10 @@ class BuildExtension(build_ext):
...
@@ -36,11 +36,10 @@ class BuildExtension(build_ext):
system_make_args
=
os
.
environ
.
get
(
"MAKEFLAGS"
,
""
)
system_make_args
=
os
.
environ
.
get
(
"MAKEFLAGS"
,
""
)
if
cmake_args
==
""
:
if
cmake_args
==
""
:
cmake_args
=
"-DCMAKE_BUILD_TYPE=Release"
cmake_args
=
"-DCMAKE_BUILD_TYPE=Release
-DFT_BUILD_TESTS=OFF
"
if
make_args
==
""
and
system_make_args
==
""
:
if
make_args
==
""
and
system_make_args
==
""
:
print
(
"For fast compilation, run:"
)
make_args
=
' -j '
print
(
'export FT_MAKE_ARGS="-j"; python setup.py install'
)
if
"PYTHON_EXECUTABLE"
not
in
cmake_args
:
if
"PYTHON_EXECUTABLE"
not
in
cmake_args
:
print
(
f
"Setting PYTHON_EXECUTABLE to
{
sys
.
executable
}
"
)
print
(
f
"Setting PYTHON_EXECUTABLE to
{
sys
.
executable
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment