OpenDAS / FAST-RNNT / Commits / 2c2dc4b9

Unverified commit 2c2dc4b9, authored Oct 31, 2022 by Daniel Povey, committed by GitHub on Oct 31, 2022.

Merge pull request #16 from pkufool/contiguous

Update rnnt_loss

Parents: c268c3d5, 03b82cc4

Showing 15 changed files with 434 additions and 478 deletions.
CMakeLists.txt                                        +1    -5
cmake/cub.cmake                                       +0   -42
fast_rnnt/csrc/CMakeLists.txt                         +0    -8
fast_rnnt/csrc/utils.cu                               +0   -63
fast_rnnt/csrc/utils.h                                +0  -101
fast_rnnt/python/csrc/CMakeLists.txt                  +0    -2
fast_rnnt/python/csrc/fast_rnnt.cu                    +0    -2
fast_rnnt/python/csrc/mutual_information.cu           +8    -0
fast_rnnt/python/csrc/utils.cu                        +0   -56
fast_rnnt/python/csrc/utils.h                         +0   -32
fast_rnnt/python/fast_rnnt/__init__.py                +0    -1
fast_rnnt/python/fast_rnnt/mutual_information.py      +6    -4
fast_rnnt/python/fast_rnnt/rnnt_loss.py             +379  -129
fast_rnnt/python/tests/rnnt_loss_test.py             +38   -30
setup.py                                              +2    -3
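Editor's note: the user-visible change in this merge is that the boolean `modified` argument of the loss functions is replaced by a string `rnnt_type` ("regular", "modified" or "constrained"), a `delay_penalty` option is added, and `px`/`py` no longer have to be contiguous. A minimal usage sketch after this commit; the shapes and values below are illustrative only, and the keyword names follow the signatures shown in the diff:

    import torch
    import fast_rnnt

    B, S, T, C = 2, 5, 12, 30              # batch, symbols, frames, vocab
    lm = torch.randn(B, S + 1, C)          # prediction-network output
    am = torch.randn(B, T, C)              # encoder output
    symbols = torch.randint(1, C, (B, S))
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

    loss = fast_rnnt.rnnt_loss_simple(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=0,
        boundary=boundary,
        rnnt_type="regular",   # was `modified=False` before this commit
        delay_penalty=0.0,     # new in this commit
        reduction="mean",
    )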
CMakeLists.txt

@@ -46,7 +46,7 @@ message(STATUS "Enabled languages: ${languages}")
 project(fast_rnnt ${languages})

-set(FT_VERSION "1.0")
+set(FT_VERSION "1.2")

 set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
 set(DEFAULT_BUILD_TYPE "Release")
 ...
@@ -133,10 +133,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 include(pybind11)
 include(torch)

-if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
-  # CUB is included in CUDA toolkit 11.0 and above
-  include(cub)
-endif()

 if(FT_BUILD_TESTS)
   enable_testing()
 ...
cmake/cub.cmake
deleted 100644 → 0

# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

function(download_cub)
  if(CMAKE_VERSION VERSION_LESS 3.11)
    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
  endif()

  include(FetchContent)

  set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
  set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")

  FetchContent_Declare(cub
    URL ${cub_URL}
    URL_HASH ${cub_HASH}
  )

  FetchContent_GetProperties(cub)
  if(NOT cub)
    message(STATUS "Downloading cub")
    FetchContent_Populate(cub)
  endif()
  message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}")

  add_library(cub INTERFACE)
  target_include_directories(cub INTERFACE ${cub_SOURCE_DIR})
endfunction()

download_cub()
fast_rnnt/csrc/CMakeLists.txt

@@ -5,7 +5,6 @@ include(transform)
 set(srcs
   mutual_information_cpu.cu
-  utils.cu
 )

 if(NOT FT_WITH_CUDA)
 ...
@@ -18,10 +17,3 @@ add_library(mutual_information_core ${srcs})
 target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
 # for <torch/extension.h>
 target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
-
-# see https://github.com/NVIDIA/thrust/issues/1401
-# and https://github.com/k2-fsa/k2/pull/917
-target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
-target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
-if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
-  target_link_libraries(mutual_information_core PUBLIC cub)
-endif()
fast_rnnt/csrc/utils.cu
deleted 100644 → 0

/**
 * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fast_rnnt/csrc/utils.h"

namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src) {
  TORCH_CHECK(src.dim() == 1, "Only support one dimension tensor");
  TORCH_CHECK(src.scalar_type() == torch::kLong, "Only support LongTensor");
  TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
  int32_t dim = src.numel();
  if (src.device().type() == torch::kCPU) {
    int64_t min_value = std::numeric_limits<int64_t>::max();
    int64_t *src_data = src.data_ptr<int64_t>();
    for (int32_t i = dim - 1; i >= 0; --i) {
      min_value = std::min(src_data[i], min_value);
      src[i] = min_value;
    }
  } else {
#ifdef FT_WITH_CUDA
    TORCH_CHECK(src.device().is_cuda());
    internal::MinOp<int64_t> min_op;
    auto src_data = src.data_ptr<int64_t>();
    internal::ConstReversedPtr<int64_t> src_ptr =
        internal::ConstReversedPtr<int64_t>(src_data, dim);
    internal::ReversedPtr<int64_t> dest_ptr =
        internal::ReversedPtr<int64_t>(src_data, dim);
    // The first time is to determine temporary device storage requirements.
    std::size_t temp_storage_bytes = 0;
    auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
                                            src_ptr, dest_ptr, min_op, dim);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
    auto d_temp =
        torch::empty({static_cast<int64_t>(temp_storage_bytes)},
                     torch::dtype(torch::kInt8).device(src.device()));
    int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
    s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
                                       src_ptr, dest_ptr, min_op, dim);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
    TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif // FT_WITH_CUDA
  }
}
} // namespace fast_rnnt
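Editor's note: the deleted file above computed a suffix minimum, i.e. an inclusive min-scan over the reversed array, on CPU with a loop and on GPU via cub::DeviceScan::InclusiveScan with reversed-pointer iterators. This commit replaces it with the pure-PyTorch `_monotonic_lower_bound` added to rnnt_loss.py further down. A sketch of the equivalent computation in plain PyTorch:

    import torch

    def monotonic_lower_bound(src: torch.Tensor) -> torch.Tensor:
        # Suffix minimum: out[i] = min(src[i], src[i+1], ..., src[-1]),
        # which is what the deleted kernel computed with an inclusive
        # min-scan over reversed pointers.
        flipped = torch.flip(src, dims=(-1,))
        scanned, _ = torch.cummin(flipped, dim=-1)
        return torch.flip(scanned, dims=(-1,))

    x = torch.tensor([0, 2, 1, 3, 6, 5, 8])
    print(monotonic_lower_bound(x))  # tensor([0, 1, 1, 3, 5, 5, 8])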
fast_rnnt/csrc/utils.h
deleted 100644 → 0

/**
 * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_

#include "torch/extension.h"

#ifdef FT_WITH_CUDA
#include "cub/cub.cuh" // NOLINT

namespace fast_rnnt {
namespace internal {

template <typename T> struct MinOp {
  __host__ __device__ __forceinline__ T operator()(const T &a,
                                                   const T &b) const {
    return (a > b) ? b : a;
  }
};

// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
  const T *data;
  // data points to the last element now
  explicit ConstReversedPtr(const T *data, int32_t size)
      : data(data + size - 1) {}

  // operator[], operator+, and operator* are required by
  // cub::DeviceScan::InclusiveScan
  __host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
    return data[-i];
  }
  __host__ __device__ __forceinline__ ConstReversedPtr
  operator+(int32_t n) const {
    ConstReversedPtr tmp(*this);
    tmp.data -= n;
    return tmp;
  }
  __host__ __device__ __forceinline__ const T &operator*() const {
    return *data;
  }
};

template <typename T> struct ReversedPtr {
  T *data;
  // data points to the last element now
  explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}

  // operator[], operator+, and operator* are required by
  // cub::DeviceScan::InclusiveScan
  __host__ __device__ __forceinline__ T &operator[](int32_t i) {
    return data[-i];
  }
  __host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
    ReversedPtr tmp(*this);
    tmp.data -= n;
    return tmp;
  }
  __host__ __device__ __forceinline__ T &operator*() { return *data; }
};

} // namespace internal
} // namespace fast_rnnt

namespace std {
// value_type is required by cub::DeviceScan::InclusiveSum
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
  typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
  typedef T value_type;
};
} // namespace std
#endif // FT_WITH_CUDA

namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
} // namespace fast_rnnt

#endif // FAST_RNNT_CSRC_UTILS_H_
fast_rnnt/python/csrc/CMakeLists.txt

@@ -6,7 +6,6 @@ include(transform)
 set(fast_rnnt_srcs
   fast_rnnt.cu
   mutual_information.cu
-  utils.cu
 )

 if(NOT FT_WITH_CUDA)
 ...
@@ -14,7 +13,6 @@ if(NOT FT_WITH_CUDA)
 endif()

 pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
 target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
 ...
fast_rnnt/python/csrc/fast_rnnt.cu

@@ -20,11 +20,9 @@
 #include "fast_rnnt/python/csrc/fast_rnnt.h"
 #include "fast_rnnt/python/csrc/mutual_information.h"
-#include "fast_rnnt/python/csrc/utils.h"

 PYBIND11_MODULE(_fast_rnnt, m) {
   m.doc() = "Python wrapper for Fast Rnnt.";
   fast_rnnt::PybindMutualInformation(m);
-  fast_rnnt::PybindUtils(m);
 }
fast_rnnt/python/csrc/mutual_information.cu

@@ -65,5 +65,13 @@ void PybindMutualInformation(py::module &m) {
       },
       py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
       py::arg("ans_grad"));
+
+  m.def("with_cuda", []() -> bool {
+#ifdef FT_WITH_CUDA
+    return true;
+#else
+    return false;
+#endif
+  });
 }
 } // namespace fast_rnnt
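Editor's note: `with_cuda` moves here from the deleted utils.cu bindings (next file). A sketch of how a caller might use it, e.g. in test setup; this snippet is illustrative, not from the diff:

    import torch
    import fast_rnnt

    # Run on CPU always; add a CUDA device only when the extension was built
    # with FT_WITH_CUDA and a GPU is actually present.
    devices = [torch.device("cpu")]
    if fast_rnnt.with_cuda() and torch.cuda.is_available():
        devices.append(torch.device("cuda", 0))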
fast_rnnt/python/csrc/utils.cu
deleted 100644 → 0

/**
 * @copyright
 * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
 *
 * @copyright
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"

namespace fast_rnnt {
void PybindUtils(py::module &m) {
  m.def(
      "monotonic_lower_bound_",
      [](torch::Tensor &src) -> void {
        DeviceGuard guard(src.device());
        if (src.dim() == 1) {
          MonotonicLowerBound(src);
        } else if (src.dim() == 2) {
          int32_t dim0 = src.sizes()[0];
          for (int32_t i = 0; i < dim0; ++i) {
            auto sub = src.index({i});
            MonotonicLowerBound(sub);
          }
        } else {
          TORCH_CHECK(false, "Only support 1 dimension and 2 dimensions tensor");
        }
      },
      py::arg("src"));

  m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
    return true;
#else
    return false;
#endif
  });
}
} // namespace fast_rnnt
fast_rnnt/python/csrc/utils.h
deleted 100644 → 0

/**
 * @brief python wrappers for utils.h
 *
 * @copyright
 * Copyright      2022  Xiaomi Corp. (author: Wei Kang)
 *
 * @copyright
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_

#include "fast_rnnt/python/csrc/fast_rnnt.h"

namespace fast_rnnt {
void PybindUtils(py::module &m);
} // namespace fast_rnnt

#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
fast_rnnt/python/fast_rnnt/__init__.py

-from _fast_rnnt import monotonic_lower_bound_
 from _fast_rnnt import with_cuda
 from .mutual_information import mutual_information_recursion
 ...
fast_rnnt/python/fast_rnnt/mutual_information.py

@@ -285,9 +285,10 @@ def mutual_information_recursion(
         for s_begin, t_begin, s_end, t_end in boundary.tolist():
             assert 0 <= s_begin <= s_end <= S
             assert 0 <= t_begin <= t_end <= T
-    # The following assertions are for efficiency
-    assert px.is_contiguous()
-    assert py.is_contiguous()
+    # The following statements are for efficiency
+    px, py = px.contiguous(), py.contiguous()

     pxy_grads = [None, None]
     scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
                                                       boundary, return_grad)
@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(
             assert 0 <= s_begin <= s_end <= S
             assert 0 <= t_begin <= t_end <= T
+    # The following statements are for efficiency
     px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
-    # The following assertions are for efficiency

     assert px_tot.ndim == 3
     assert py_tot.ndim == 3
 ...
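Editor's note: this is the change the source branch name (`contiguous`) refers to: instead of asserting that callers pass contiguous `px`/`py`, the function now makes contiguous copies itself. A sketch of the caller-visible difference, with made-up shapes:

    import torch

    # A permuted view is non-contiguous; before this commit it failed the
    # assertion, now the function simply copies it to contiguous storage.
    px = torch.randn(2, 13, 5).permute(0, 2, 1)  # (B, S, T+1) as a view
    assert not px.is_contiguous()
    px = px.contiguous()                         # what the function now does
    assert px.is_contiguous()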
fast_rnnt/python/fast_rnnt/rnnt_loss.py

@@ -16,7 +16,6 @@
 import os
-import fast_rnnt
 import torch
 from torch import Tensor
 from typing import Optional, Tuple, Union
@@ -26,13 +25,13 @@ from .mutual_information import mutual_information_recursion
 def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     """
     Insert -inf's into `px` in appropriate places if `boundary` is not
-    None.  If boundary == None and modified == False, px[:,:,-1] will
+    None.  If boundary == None and rnnt_type == "regular", px[:,:,-1] will
     be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
     to be -infinity.
     Args:
-      px: a Tensor of shape [B][S][T+1] (this function is only
-          called if modified == False, see other docs for `modified`)
+      px: a Tensor of shape [B][S][T+1] (this function is only
+          called if rnnt_type == "regular", see other docs for `rnnt_type`)
           px is modified in-place and returned.
       boundary: None, or a Tensor of shape [B][3] containing
           [s_begin, t_begin, s_end, t_end]; we need only t_end.
@@ -49,8 +48,8 @@ def get_rnnt_logprobs(
     am: Tensor,
     symbols: Tensor,
     termination_symbol: int,
+    rnnt_type: str = "regular",
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
 ) -> Tuple[Tensor, Tensor]:
     """
     Reduces RNN-T problem (the simple case, where joiner network is just
 ...
@@ -97,20 +96,32 @@ def get_rnnt_logprobs(
          [0, 0, S, T]
          if boundary is not supplied.
          Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+            emitting a blank (i.e., emitting a symbol does not take you
+            to the next frame).
+        `modified`: A modified version of rnnt that will take you to the next
+            frame whether emitting a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+            frame when you emit a non-blank symbol, but this is done
+            by "forcing" you to take the blank transition from the
+            *next* context on the *current* frame, e.g. if we emit
+            c given "a b" context, we are forced to emit "blank"
+            given "b c" context on the current frame.
     Returns:
         (px, py) (the names are quite arbitrary).
-          px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
+          px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
+              [B][S][T] if rnnt_type is not regular.
          py: logprobs, of shape [B][S+1][T]
       in the recursion::
          p[b,0,0] = 0.0
-         if !modified:
+         if rnnt_type == "regular":
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
-         if modified:
+         if rnnt_type != "regular":
             p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                                p[b,s,t-1] + py[b,s,t-1])
       .. where p[b][s][t] is the "joint score" of the pair of subsequences
@@ -121,21 +132,22 @@ def get_rnnt_logprobs(
       (s,t) by one in the t direction,
       i.e. of emitting the termination/next-frame symbol.
-      if !modified, px[:,:,T] equals -infinity, meaning on the
+      if rnnt_type == "regular", px[:,:,T] equals -infinity, meaning on the
       "one-past-the-last" frame we cannot emit any symbols.
       This is simply a way of incorporating
       the probability of the termination symbol on the last frame.
     """
-    assert lm.ndim == 3
-    assert am.ndim == 3
-    assert lm.shape[0] == am.shape[0]
-    assert lm.shape[2] == am.shape[2]
+    assert lm.ndim == 3, lm.ndim
+    assert am.ndim == 3, am.ndim
+    assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
+    assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
-    assert symbols.shape == (B, S)
-    assert S >= 1
-    assert T >= S
+    assert symbols.shape == (B, S), symbols.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     # subtracting am_max and lm_max is to ensure the probs are in a good range
     # to do exp() without causing underflow or overflow.
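Editor's note: for readers unfamiliar with the recursion in the docstring above, here is a naive reference sketch of it in pure Python for one batch element. The library evaluates this in C++/CUDA; this is only to make the `regular` and non-`regular` path types concrete:

    import math

    NEG_INF = float("-inf")

    def log_add(a: float, b: float) -> float:
        if a == NEG_INF:
            return b
        if b == NEG_INF:
            return a
        m = max(a, b)
        return m + math.log1p(math.exp(min(a, b) - m))

    def joint_score(px, py, rnnt_type: str = "regular") -> float:
        # px: (S, T+1) nested lists for "regular", (S, T) otherwise;
        # py: (S+1, T) nested lists.
        S, T = len(py) - 1, len(py[0])
        p = [[NEG_INF] * (T + 1) for _ in range(S + 1)]
        p[0][0] = 0.0
        for s in range(S + 1):
            for t in range(T + 1):
                if s == 0 and t == 0:
                    continue
                if rnnt_type == "regular":
                    a = p[s - 1][t] + px[s - 1][t] if s > 0 else NEG_INF
                else:  # "modified"/"constrained" move diagonally in (s, t)
                    a = (p[s - 1][t - 1] + px[s - 1][t - 1]
                         if s > 0 and t > 0 else NEG_INF)
                b = p[s][t - 1] + py[s][t - 1] if t > 0 else NEG_INF
                p[s][t] = log_add(a, b)
        return p[S][T]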
@@ -162,7 +174,7 @@ def get_rnnt_logprobs(
         -1
     )  # [B][S][T]
-    if not modified:
+    if rnnt_type == "regular":
         px_am = torch.cat(
             (
                 px_am,
 ...
@@ -189,8 +201,10 @@ def get_rnnt_logprobs(
     py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
     py = py_am + py_lm - normalizers

-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]

     return (px, py)
@@ -201,7 +215,8 @@ def rnnt_loss_simple(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
+    delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
     return_grad: bool = False,
 ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
 ...
@@ -226,8 +241,23 @@ def rnnt_loss_simple(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`
+        (full description as in get_rnnt_logprobs above).
+      delay_penalty: A constant value to penalize symbol delay, this may be
+          needed when training with time masking, to avoid the time-masking
+          encouraging the network to delay symbols.
+          See https://github.com/k2-fsa/k2/issues/955 for more details.
       reduction:
         Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
         `none`: no reduction will be applied.
@@ -255,8 +285,24 @@ def rnnt_loss_simple(
         symbols=symbols,
         termination_symbol=termination_symbol,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )
+
+    if delay_penalty > 0.0:
+        B, S, T0 = px.shape
+        T = T0 if rnnt_type != "regular" else T0 - 1
+        if boundary is None:
+            offset = torch.tensor(
+                (T - 1) / 2,
+                dtype=px.dtype,
+                device=px.device,
+            ).expand(B, 1, 1)
+        else:
+            offset = (boundary[:, 3] - 1) / 2
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
+        penalty = penalty * delay_penalty
+        px += penalty.to(px.dtype)
+
     scores_and_grads = mutual_information_recursion(
         px=px, py=py, boundary=boundary, return_grad=return_grad
     )
@@ -268,9 +314,9 @@ def rnnt_loss_simple(
     elif reduction == "sum":
         loss = -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
     return (loss, scores_and_grads[1]) if return_grad else loss
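Editor's note: the `delay_penalty` block above adds `(midpoint - t) * delay_penalty` to `px`, rewarding symbol emissions before the sequence midpoint and penalizing later ones. A numeric sketch for the no-boundary case:

    import torch

    T, delay_penalty = 8, 0.1      # 8 frames, illustrative penalty scale
    offset = (T - 1) / 2           # midpoint used when boundary is None
    penalty = (offset - torch.arange(T)) * delay_penalty
    print(penalty)
    # tensor([ 0.3500,  0.2500,  0.1500,  0.0500,
    #         -0.0500, -0.1500, -0.2500, -0.3500])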
@@ -279,7 +325,7 @@ def get_rnnt_logprobs_joint(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Reduces RNN-T problem to a compact, standard form that can then be given
     (with boundaries) to mutual_information_recursion().
 ...
@@ -300,21 +346,33 @@ def get_rnnt_logprobs_joint(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`
+        (full description as in get_rnnt_logprobs above).
     Returns:
         (px, py) (the names are quite arbitrary)::
-          px: logprobs, of shape [B][S][T+1]
+          px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
+              [B][S][T] if rnnt_type is not regular.
           py: logprobs, of shape [B][S+1][T]
       in the recursion::
          p[b,0,0] = 0.0
-         if !modified:
+         if rnnt_type == "regular":
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
-         if modified:
+         if rnnt_type != "regular":
             p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                                p[b,s,t-1] + py[b,s,t-1])
       .. where p[b][s][t] is the "joint score" of the pair of subsequences of
@@ -324,17 +382,18 @@ def get_rnnt_logprobs_joint(
       of extending the subsequences of length (s,t) by one in the t direction,
       i.e. of emitting the termination/next-frame symbol.
-      if !modified, px[:,:,T] equals -infinity, meaning on the
+      if rnnt_type == "regular", px[:,:,T] equals -infinity, meaning on the
       "one-past-the-last" frame we cannot emit any symbols.
       This is simply a way of incorporating
       the probability of the termination symbol on the last frame.
     """
-    assert logits.ndim == 4
+    assert logits.ndim == 4, logits.ndim
     (B, T, S1, C) = logits.shape
     S = S1 - 1
-    assert symbols.shape == (B, S)
-    assert S >= 1
-    assert T >= S
+    assert symbols.shape == (B, S), symbols.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))
@@ -344,7 +403,7 @@ def get_rnnt_logprobs_joint(
     ).squeeze(-1)
     px = px.permute((0, 2, 1))
-    if not modified:
+    if rnnt_type == "regular":
         px = torch.cat(
             (
                 px,
 ...
@@ -361,11 +420,11 @@ def get_rnnt_logprobs_joint(
         logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
     )  # [B][S+1][T]
     py -= normalizers

-    px = px.contiguous()
-    py = py.contiguous()
-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]

     return (px, py)
@@ -375,7 +434,8 @@ def rnnt_loss(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
+    delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
 ) -> Tensor:
     """A normal RNN-T loss, which uses a 'joiner' network output as input,
 ...
@@ -395,8 +455,23 @@ def rnnt_loss(
       [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
       [0, 0, S, T] if boundary is not supplied.
       Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`
+        (full description as in get_rnnt_logprobs above).
+      delay_penalty: A constant value to penalize symbol delay, this may be
+          needed when training with time masking, to avoid the time-masking
+          encouraging the network to delay symbols.
+          See https://github.com/k2-fsa/k2/issues/955 for more details.
       reduction:
         Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
         `none`: no reduction will be applied.
@@ -414,8 +489,24 @@ def rnnt_loss(
         symbols=symbols,
         termination_symbol=termination_symbol,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )
+
+    if delay_penalty > 0.0:
+        B, S, T0 = px.shape
+        T = T0 if rnnt_type != "regular" else T0 - 1
+        if boundary is None:
+            offset = torch.tensor(
+                (T - 1) / 2,
+                dtype=px.dtype,
+                device=px.device,
+            ).expand(B, 1, 1)
+        else:
+            offset = (boundary[:, 3] - 1) / 2
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
+        penalty = penalty * delay_penalty
+        px += penalty.to(px.dtype)
+
     negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
     if reduction == "none":
         return -negated_loss
@@ -424,30 +515,63 @@ def rnnt_loss(
     elif reduction == "sum":
         return -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
+
+def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
+    """Compute a monotonically increasing lower bound of the tensor `x` on the
+    last dimension. The basic idea is: we traverse the tensor in reverse order,
+    and update current element with the following statement,
+
+        min_value = min(x[i], min_value)
+        x[i] = min_value
+
+    >>> import torch
+    >>> x = torch.tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
+    >>> _monotonic_lower_bound(x)
+    tensor([0, 1, 1, 3, 5, 5, 8], dtype=torch.int32)
+    >>> x
+    tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
+    >>> x = torch.randint(20, (3, 6), dtype=torch.int32)
+    >>> x
+    tensor([[12, 18,  5,  4, 18, 17],
+            [11, 14, 14,  3, 10,  4],
+            [19,  3,  8, 13,  7, 19]], dtype=torch.int32)
+    >>> _monotonic_lower_bound(x)
+    tensor([[ 4,  4,  4,  4, 17, 17],
+            [ 3,  3,  3,  3,  4,  4],
+            [ 3,  3,  7,  7,  7, 19]], dtype=torch.int32)
+
+    Args:
+      x:
+        The source tensor.
+
+    Returns:
+      Returns a tensor which is monotonic on the last dimension
+      (i.e. satisfy `x[i] <= x[i+1]`).
+    """
+    x = torch.flip(x, dims=(-1,))
+    x, _ = torch.cummin(x, dim=-1)
+    x = torch.flip(x, dims=(-1,))
+    return x
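Editor's note: a quick property check, not part of the commit, assuming `_monotonic_lower_bound` above is in scope. It compares the flip/cummin/flip implementation against the reverse-order loop its docstring describes:

    import torch

    x = torch.randint(20, (100,))
    expected = x.clone()
    min_value = int(torch.iinfo(torch.int64).max)
    # The reverse-order loop from the docstring above.
    for i in range(len(expected) - 1, -1, -1):
        min_value = min(int(expected[i]), min_value)
        expected[i] = min_value
    assert torch.equal(_monotonic_lower_bound(x), expected)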
 def _adjust_pruning_lower_bound(
     s_begin: torch.Tensor, s_range: int
 ) -> torch.Tensor:
-    """Adjust s_begin (pruning lower bound) to make it satisfied the following
-    constrains
+    """Adjust s_begin (pruning lower bounds) to make it satisfy the following
+    constraints
       - monotonic increasing, i.e. s_begin[i] <= s_begin[i + 1]
       - start with symbol 0 at first frame.
-      - s_begin[i + 1] - s_begin[i] < s_range, whicn means that we can't skip
+      - s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
         any symbols.

-    To make it monotonic increasing, we can use `monotonic_lower_bound` function
-    in k2, which guarantee `s_begin[i] <= s_begin[i + 1]`. The main idea is:
+    To make it monotonic increasing, we can use `_monotonic_lower_bound` above,
+    which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea is:
     traverse the array in reverse order and update the elements by
-    `min_value = min(a_begin[i], min_value)`, the initial `min_value` set to
-    `inf`.
+    `min_value = min(a_begin[i], min_value)`.

     The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range`
-    constrain is a little tricky. We first transform `s_begin` with
+    constraint is a little tricky. We first transform `s_begin` with
     `s_begin = -(s_begin - (s_range - 1) * torch.arange(0,T))`
     then we make the transformed `s_begin` monotonic increasing, after that,
     we transform back `s_begin` with the same formula as the previous
 ...
@@ -467,21 +591,22 @@ def _adjust_pruning_lower_bound(
     """
     # s_begin (B, T)
     (B, T) = s_begin.shape
-    fast_rnnt.monotonic_lower_bound_(s_begin)
+    s_begin = _monotonic_lower_bound(s_begin)
     # do the magic transformation
     s_begin = -(
         s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
     )
     # make the transformed tensor to be non-decreasing
-    fast_rnnt.monotonic_lower_bound_(s_begin)
+    s_begin = _monotonic_lower_bound(s_begin)
     # make start symbol to be zero.
-    s_begin = torch.where(s_begin < 0, 0, s_begin)
+    s_begin = torch.clamp(s_begin, min=0)
     # do the magic transformation again to recover s_begin
     s_begin = -(
         s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
     )
     return s_begin


 # To get more insight of how we calculate pruning bounds, please read
 # chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
 # (https://arxiv.org/pdf/2206.13236.pdf)
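Editor's note: a worked sketch of the round trip above (lower bound, slope transform, lower bound again, clamp, transform back), assuming both functions above are in scope. The input jumps from 0 to 3, which would skip symbols when s_range = 2:

    import torch

    s_begin = torch.tensor([[0, 0, 3, 3, 4]])  # (B=1, T=5), gap of 3 at t=2
    print(_adjust_pruning_lower_bound(s_begin, 2))
    # tensor([[0, 1, 2, 3, 4]]): monotonic, starts at 0, adjacent gaps < 2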
@@ -512,9 +637,9 @@ def get_rnnt_prune_ranges(
     Note:
       For the generated tensor ranges(assuming batch size is 1), ranges[:, 0]
-      is a monotonic increasing tensor from 0 to `len(symbols)` and it satisfies
-      `ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
-      symbols.
+      is a monotonic increasing tensor from 0 to `len(symbols) - s_range` and
+      it satisfies `ranges[t+1, 0] - ranges[t, 0] < s_range` which means we
+      won't skip any symbols.

     Args:
       px_grad:
@@ -529,21 +654,21 @@ def get_rnnt_prune_ranges(
       s_range:
         How many symbols to keep for each frame.
     Returns:
-      A tensor contains the kept symbols indexes for each frame, with shape
-      (B, T, s_range).
+      A tensor with the shape of (B, T, s_range) containing the indexes of the
+      kept symbols for each frame.
     """
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
-    assert T1 in [T, T + 1]
+    assert T1 in [T, T + 1], T1
     S1 = S + 1
-    assert py_grad.shape == (B, S + 1, T)
-    assert boundary.shape == (B, 4)
-    assert S >= 1
-    assert T >= S
+    assert py_grad.shape == (B, S + 1, T), py_grad.shape
+    assert boundary.shape == (B, 4), boundary.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)

     # s_range > S means we won't prune out any symbols. To make indexing with
-    # ranges runs normally, s_range should be equal to or less than ``S + 1``.
+    # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
         s_range = S + 1
@@ -591,16 +716,17 @@ def get_rnnt_prune_ranges(
     mask = mask < boundary[:, 3].reshape(B, 1) - 1

     s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
-    # handle the cases when `len(symbols) < s_range`
+    # handle the cases where `len(symbols) < s_range`
     s_begin_padding = torch.clamp(s_begin_padding, min=0)

     s_begin = torch.where(mask, s_begin, s_begin_padding)

-    # adjusting lower bound to make it satisfied some constrains, see docs in
-    # `_adjust_pruning_lower_bound` for more details of these constrains.
-    # T1 == T here means we are using the modified version of transducer,
-    # the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
-    # it only emits one symbol per frame.
+    # adjusting lower bound to make it satisfy some constraints, see docs in
+    # `_adjust_pruning_lower_bound` for more details of these constraints.
+    # T1 == T here means we are using the non-regular (i.e. modified rnnt or
+    # constrained rnnt) version of transducer, the third constraint becomes
+    # `s_begin[i + 1] - s_begin[i] < 2`, because it only emits one symbol per
+    # frame.
     s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)

     ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
 ...
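Editor's note: the final `ranges` construction above is each adjusted `s_begin[t]` broadcast against `0..s_range-1`, i.e. a contiguous window of symbol positions per frame. A small sketch:

    import torch

    s_begin = torch.tensor([[0, 1, 2, 3, 4]])  # (B=1, T=5), already adjusted
    s_range = 2
    ranges = s_begin.reshape(1, 5, 1).expand(1, 5, s_range) + torch.arange(s_range)
    print(ranges[0])
    # tensor([[0, 1],
    #         [1, 2],
    #         [2, 3],
    #         [3, 4],
    #         [4, 5]])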
@@ -613,8 +739,8 @@ def get_rnnt_prune_ranges(

 def do_rnnt_pruning(
     am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Prune the output of encoder(am) output and prediction network(lm)
-    output of RNNT.
+    """Prune the output of encoder(am) and prediction network(lm)
+    with ranges generated by `get_rnnt_prune_ranges`.

     Args:
       am:
 ...
@@ -632,9 +758,9 @@ def do_rnnt_pruning(
     # am (B, T, C)
     # lm (B, S + 1, C)
     # ranges (B, T, s_range)
-    assert ranges.shape[0] == am.shape[0]
-    assert ranges.shape[0] == lm.shape[0]
-    assert am.shape[1] == ranges.shape[1]
+    assert ranges.shape[0] == am.shape[0], (ranges.shape[0], am.shape[0])
+    assert ranges.shape[0] == lm.shape[0], (ranges.shape[0], lm.shape[0])
+    assert am.shape[1] == ranges.shape[1], (am.shape[1], ranges.shape[1])
     (B, T, s_range) = ranges.shape
     (B, S1, C) = lm.shape
     S = S1 - 1
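Editor's note: taken together, these helpers support the pruned-loss workflow from the Pruned RNN-T paper cited earlier. An end-to-end sketch under stated assumptions (fast_rnnt installed, shapes illustrative, and a plain sum standing in for a real joiner network):

    import torch
    import fast_rnnt

    B, S, T, C = 2, 5, 12, 30
    lm = torch.randn(B, S + 1, C, requires_grad=True)
    am = torch.randn(B, T, C, requires_grad=True)
    symbols = torch.randint(1, C, (B, S))
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

    # Step 1: simple loss, also returning occupation gradients of px/py.
    simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
        lm=lm, am=am, symbols=symbols, termination_symbol=0,
        boundary=boundary, reduction="sum", return_grad=True,
    )
    # Step 2: derive per-frame symbol windows from those gradients.
    ranges = fast_rnnt.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=3,
    )
    # Step 3: prune encoder/prediction outputs to those windows.
    am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
    # Step 4: a real model would apply its joiner here; a sum stands in.
    logits = am_pruned + lm_pruned                    # (B, T, s_range, C)
    pruned_loss = fast_rnnt.rnnt_loss_pruned(
        logits=logits, symbols=symbols, ranges=ranges,
        termination_symbol=0, boundary=boundary, reduction="sum",
    )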
@@ -672,9 +798,9 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
              [ 8,  9,  5,  6,  7],
              [12, 13, 14, 10, 11]]])
     """
-    assert src.dim() == 3
+    assert src.dim() == 3, src.dim()
     (B, T, S) = src.shape
-    assert shifts.shape == (B, T)
+    assert shifts.shape == (B, T), shifts.shape

     index = (
         torch.arange(S, device=src.device)
 ...
@@ -692,7 +818,7 @@ def get_rnnt_logprobs_pruned(
     ranges: Tensor,
     termination_symbol: int,
     boundary: Tensor,
-    modified: bool = False,
+    rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Construct px, py for mutual_information_recursion with pruned output.
 ...
@@ -704,6 +830,12 @@ def get_rnnt_logprobs_pruned(
         {0..C-1}.
       ranges:
         A tensor containing the symbol ids for each frame that we want to keep.
+        It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
+        contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
+        ``logits[b,t,:,:]`` represents the logits with positions
+        ``s, s + 1, ... s + s_range - 1``.
+        See docs in :func:`get_rnnt_prune_ranges` for more details of what
+        ranges contains.
       termination_symbol:
         the termination symbol, with 0 <= termination_symbol < C
       boundary:
@@ -712,21 +844,53 @@ def get_rnnt_logprobs_pruned(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`
+        (full description as in get_rnnt_logprobs above).
     Returns:
-      Return the px (B, S, T) if modified else (B, S, T + 1) and
-      py (B, S + 1, T) needed by mutual_information_recursion.
+      (px, py) (the names are quite arbitrary)::
+          px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
+              [B][S][T] if rnnt_type is not regular.
+          py: logprobs, of shape [B][S+1][T]
+      in the recursion::
+          p[b,0,0] = 0.0
+          if rnnt_type == "regular":
+              p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                                 p[b,s,t-1] + py[b,s,t-1])
+          if rnnt_type != "regular":
+              p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
+                                 p[b,s,t-1] + py[b,s,t-1])
+      .. where p[b][s][t] is the "joint score" of the pair of subsequences of
+      length s and t respectively. px[b][s][t] represents the probability of
+      extending the subsequences of length (s,t) by one in the s direction,
+      given the particular symbol, and py[b][s][t] represents the probability
+      of extending the subsequences of length (s,t) by one in the t direction,
+      i.e. of emitting the termination/next-frame symbol.
+      if `rnnt_type == "regular"`, px[:,:,T] equals -infinity, meaning on the
+      "one-past-the-last" frame we cannot emit any symbols.
+      This is simply a way of incorporating
+      the probability of the termination symbol on the last frame.
     """
     # logits (B, T, s_range, C)
     # symbols (B, S)
     # ranges (B, T, s_range)
-    assert logits.ndim == 4
+    assert logits.ndim == 4, logits.ndim
     (B, T, s_range, C) = logits.shape
-    assert ranges.shape == (B, T, s_range)
+    assert ranges.shape == (B, T, s_range), ranges.shape
     (B, S) = symbols.shape
-    assert S >= 1
-    assert T >= S
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     normalizers = torch.logsumexp(logits, dim=3)
 ...
@@ -774,7 +938,7 @@ def get_rnnt_logprobs_pruned(
     px = px.permute((0, 2, 1))
-    if not modified:
+    if rnnt_type == "regular":
         px = torch.cat(
             (
                 px,
 ...
@@ -807,11 +971,10 @@ def get_rnnt_logprobs_pruned(
     # (B, S + 1, T)
     py = py.permute((0, 2, 1))

-    px = px.contiguous()
-    py = py.contiguous()
-
-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]

     return (px, py)
...
@@ -822,12 +985,13 @@ def rnnt_loss_pruned(
...
@@ -822,12 +985,13 @@ def rnnt_loss_pruned(
ranges
:
Tensor
,
ranges
:
Tensor
,
termination_symbol
:
int
,
termination_symbol
:
int
,
boundary
:
Tensor
=
None
,
boundary
:
Tensor
=
None
,
modified
:
bool
=
False
,
rnnt_type
:
str
=
"regular"
,
delay_penalty
:
float
=
0.0
,
reduction
:
Optional
[
str
]
=
"mean"
,
reduction
:
Optional
[
str
]
=
"mean"
,
)
->
Tensor
:
)
->
Tensor
:
"""A RNN-T loss with pruning, which uses a pruned 'joiner'
network output
"""A RNN-T loss with pruning, which uses
the output of
a pruned 'joiner'
as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
network
as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
s_range means the symbols
number
kept for each frame.
s_range means the
number of
symbols kept for each frame.
Args:
Args:
logits:
logits:
...
@@ -838,6 +1002,12 @@ def rnnt_loss_pruned(
...
@@ -838,6 +1002,12 @@ def rnnt_loss_pruned(
of the sequence.
of the sequence.
ranges:
ranges:
A tensor containing the symbol ids for each frame that we want to keep.
A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol:
termination_symbol:
The identity of the termination symbol, must be in {0..C-1}
The identity of the termination symbol, must be in {0..C-1}
boundary:
boundary:
...
@@ -845,8 +1015,23 @@ def rnnt_loss_pruned(
...
@@ -845,8 +1015,23 @@ def rnnt_loss_pruned(
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
[0, 0, S, T] if boundary is not supplied.
[0, 0, S, T] if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will
rnnt_type:
also be consumed, so at most 1 symbol can appear per frame.
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt that taking you to the next frame only if
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that will take you to the next
frame either emitting a blank or a non-blank symbol.
`constrained`: A version likes the modified one that will go to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
delay_penalty: A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time-masking
encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
`none`: no reduction will be applied.
...
@@ -854,8 +1039,8 @@ def rnnt_loss_pruned(
...
@@ -854,8 +1039,8 @@ def rnnt_loss_pruned(
`sum`: the output will be summed.
`sum`: the output will be summed.
Default: `mean`
Default: `mean`
Returns:
Returns:
If re
curs
ion is `none`, returns a tensor of shape (B,), containing the
If re
duct
ion is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each
element
of the batch, otherwise a scalar
total RNN-T loss values for each
sequence
of the batch, otherwise a scalar
with the reduction applied.
with the reduction applied.
"""
"""
px
,
py
=
get_rnnt_logprobs_pruned
(
px
,
py
=
get_rnnt_logprobs_pruned
(
...
@@ -864,8 +1049,24 @@ def rnnt_loss_pruned(
...
@@ -864,8 +1049,24 @@ def rnnt_loss_pruned(
ranges
=
ranges
,
ranges
=
ranges
,
termination_symbol
=
termination_symbol
,
termination_symbol
=
termination_symbol
,
boundary
=
boundary
,
boundary
=
boundary
,
modified
=
modified
,
rnnt_type
=
rnnt_type
,
)
)
if
delay_penalty
>
0.0
:
B
,
S
,
T0
=
px
.
shape
T
=
T0
if
rnnt_type
!=
"regular"
else
T0
-
1
if
boundary
is
None
:
offset
=
torch
.
tensor
(
(
T
-
1
)
/
2
,
dtype
=
px
.
dtype
,
device
=
px
.
device
,
).
expand
(
B
,
1
,
1
)
else
:
offset
=
(
boundary
[:,
3
]
-
1
)
/
2
penalty
=
offset
.
reshape
(
B
,
1
,
1
)
-
torch
.
arange
(
T0
,
device
=
px
.
device
).
reshape
(
1
,
1
,
T0
)
penalty
=
penalty
*
delay_penalty
px
+=
penalty
.
to
(
px
.
dtype
)
negated_loss
=
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
boundary
)
negated_loss
=
mutual_information_recursion
(
px
=
px
,
py
=
py
,
boundary
=
boundary
)
if
reduction
==
"none"
:
if
reduction
==
"none"
:
return
-
negated_loss
return
-
negated_loss
...
@@ -874,9 +1075,9 @@ def rnnt_loss_pruned(
...
@@ -874,9 +1075,9 @@ def rnnt_loss_pruned(
      elif reduction == "sum":
          return -torch.sum(negated_loss)
      else:
-         assert (
-             False
-         ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+         raise ValueError(
+             f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+         )
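Note: to see the renamed arguments end to end, here is a hedged sketch of the pruned-loss pipeline (all sizes and the blank id 0 are assumptions for illustration, and the plain pruned_am + pruned_lm sum stands in for a real joiner network, as in the tests below):

    import torch
    import fast_rnnt

    B, T, S, C, r = 2, 10, 5, 20, 3  # hypothetical sizes
    terminal_symbol = 0              # assumed blank id
    am = torch.randn(B, T, C)
    lm = torch.randn(B, S + 1, C)
    symbols = torch.randint(1, C, (B, S))
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

    # Simple loss first, keeping the gradients used to choose pruning bounds.
    simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
        lm=lm, am=am, symbols=symbols,
        termination_symbol=terminal_symbol,
        boundary=boundary,
        rnnt_type="modified",  # new argument replacing modified=True
        return_grad=True,
        reduction="none",
    )

    # Pruning bounds and pruned joiner input.
    ranges = fast_rnnt.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=r
    )
    pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
    logits = pruned_am + pruned_lm

    # Pruned loss with the new rnnt_type / delay_penalty arguments.
    pruned_loss = fast_rnnt.rnnt_loss_pruned(
        logits=logits, symbols=symbols, ranges=ranges,
        termination_symbol=terminal_symbol,
        boundary=boundary,
        rnnt_type="modified",
        delay_penalty=0.01,
        reduction="mean",
    )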
def get_rnnt_logprobs_smoothed(
...
@@ -887,7 +1088,7 @@ def get_rnnt_logprobs_smoothed(
      lm_only_scale: float = 0.1,
      am_only_scale: float = 0.1,
      boundary: Optional[Tensor] = None,
-     modified: bool = False,
+     rnnt_type: str = "regular",
  ) -> Tuple[Tensor, Tensor]:
"""
"""
Reduces RNN-T problem (the simple case, where joiner network is just
Reduces RNN-T problem (the simple case, where joiner network is just
...
@@ -950,18 +1151,32 @@ def get_rnnt_logprobs_smoothed(
...
@@ -950,18 +1151,32 @@ def get_rnnt_logprobs_smoothed(
      Most likely you will want begin_symbol and begin_frame to be zero.
      modified: if True, each time a real symbol is consumed a frame will
        also be consumed, so at most 1 symbol can appear per frame.
+     rnnt_type:
+       Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+       `regular`: The regular rnnt, which takes you to the next frame only
+         when emitting a blank (i.e., emitting a symbol does not take you
+         to the next frame).
+       `modified`: A modified version of rnnt that takes you to the next
+         frame whether you emit a blank or a non-blank symbol.
+       `constrained`: A version like the modified one that goes to the next
+         frame when you emit a non-blank symbol, but does so by "forcing"
+         you to take the blank transition from the *next* context on the
+         *current* frame, e.g. if we emit c given "a b" context, we are
+         forced to emit "blank" given "b c" context on the current frame.
      Returns:
        (px, py) (the names are quite arbitrary).
-       px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
+       px: logprobs, of shape [B][S][T+1] if rnnt_type == "regular",
+           [B][S][T] if rnnt_type != "regular".
        py: logprobs, of shape [B][S+1][T]
      in the recursion::
        p[b,0,0] = 0.0
-       if !modified:
+       if rnnt_type == "regular":
            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                               p[b,s,t-1] + py[b,s,t-1])
-       if modified:
+       if rnnt_type != "regular":
            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                               p[b,s,t-1] + py[b,s,t-1])
      .. where p[b][s][t] is the "joint score" of the pair of subsequences
...
@@ -976,15 +1191,16 @@ def get_rnnt_logprobs_smoothed(
      we cannot emit any symbols. This is simply a way of incorporating
      the probability of the termination symbol on the last frame.
      """
-     assert lm.ndim == 3
+     assert lm.ndim == 3, lm.ndim
-     assert am.ndim == 3
+     assert am.ndim == 3, am.ndim
-     assert lm.shape[0] == am.shape[0]
+     assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
-     assert lm.shape[2] == am.shape[2]
+     assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])
      (B, T, C) = am.shape
      S = lm.shape[1] - 1
-     assert symbols.shape == (B, S)
+     assert symbols.shape == (B, S), symbols.shape
-     assert S >= 1
+     assert S >= 1, S
-     assert T >= S
+     assert T >= S, (T, S)
+     assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
      # Caution: some parts of this code are a little less clear than they could
      # be due to optimizations. In particular it may not be totally obvious that
...
@@ -1036,7 +1252,7 @@ def get_rnnt_logprobs_smoothed(
          - 1
      )  # [B][S][T]
-     if not modified:
+     if rnnt_type == "regular":
          px_am = torch.cat(
              (
                  px_am,
...
@@ -1095,8 +1311,10 @@ def get_rnnt_logprobs_smoothed(
          + py_amonly * am_only_scale
      )
-     if not modified:
+     if rnnt_type == "regular":
          px_interp = fix_for_boundary(px_interp, boundary)
+     elif rnnt_type == "constrained":
+         px_interp += py_interp[:, 1:, :]
      return (px_interp, py_interp)
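Note: the new constrained branch above is a single shape-aligned add; a toy sketch of the alignment it relies on (shapes follow the docstring for a non-regular topology; tensor values are made up):

    import torch

    B, S, T = 1, 2, 3
    px = torch.zeros(B, S, T)      # logprob of emitting symbol s+1 at frame t
    py = torch.randn(B, S + 1, T)  # logprob of emitting blank in context s at frame t

    # Constrained rnnt: emitting a symbol on frame t also pays for the blank
    # transition of the *next* context on the same frame, hence the shift by one:
    px = px + py[:, 1:, :]         # py[:, 1:, :] has shape [B][S][T]
    print(px.shape)                # torch.Size([1, 2, 3])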
...
@@ -1109,7 +1327,8 @@ def rnnt_loss_smoothed(
      lm_only_scale: float = 0.1,
      am_only_scale: float = 0.1,
      boundary: Optional[Tensor] = None,
-     modified: bool = False,
+     rnnt_type: str = "regular",
+     delay_penalty: float = 0.0,
      reduction: Optional[str] = "mean",
      return_grad: bool = False,
  ) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
...
@@ -1141,8 +1360,23 @@ def rnnt_loss_smoothed(
      [0, 0, S, T] if boundary is not supplied.
      Most likely you will want begin_symbol and begin_frame to be zero.
-     modified: if True, each time a real symbol is consumed a frame will
-       also be consumed, so at most 1 symbol can appear per frame.
+     rnnt_type:
+       Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+       `regular`: The regular rnnt, which takes you to the next frame only
+         when emitting a blank (i.e., emitting a symbol does not take you
+         to the next frame).
+       `modified`: A modified version of rnnt that takes you to the next
+         frame whether you emit a blank or a non-blank symbol.
+       `constrained`: A version like the modified one that goes to the next
+         frame when you emit a non-blank symbol, but does so by "forcing"
+         you to take the blank transition from the *next* context on the
+         *current* frame, e.g. if we emit c given "a b" context, we are
+         forced to emit "blank" given "b c" context on the current frame.
+     delay_penalty: A constant value to penalize symbol delay; this may be
+       needed when training with time masking, to avoid the time masking
+       encouraging the network to delay symbols.
+       See https://github.com/k2-fsa/k2/issues/955 for more details.
      reduction:
        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
        `none`: no reduction will be applied.
...
@@ -1173,8 +1407,24 @@ def rnnt_loss_smoothed(
          lm_only_scale=lm_only_scale,
          am_only_scale=am_only_scale,
          boundary=boundary,
-         modified=modified,
+         rnnt_type=rnnt_type,
      )
+     if delay_penalty > 0.0:
+         B, S, T0 = px.shape
+         T = T0 if rnnt_type != "regular" else T0 - 1
+         if boundary is None:
+             offset = torch.tensor(
+                 (T - 1) / 2,
+                 dtype=px.dtype,
+                 device=px.device,
+             ).expand(B, 1, 1)
+         else:
+             offset = (boundary[:, 3] - 1) / 2
+         penalty = offset.reshape(B, 1, 1) - torch.arange(
+             T0, device=px.device
+         ).reshape(1, 1, T0)
+         penalty = penalty * delay_penalty
+         px += penalty.to(px.dtype)
      scores_and_grads = mutual_information_recursion(
          px=px, py=py, boundary=boundary, return_grad=return_grad
      )
...
@@ -1186,7 +1436,7 @@ def rnnt_loss_smoothed(
      elif reduction == "sum":
          loss = -torch.sum(negated_loss)
      else:
-         assert (
-             False
-         ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+         raise ValueError(
+             f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+         )
      return (loss, scores_and_grads[1]) if return_grad else loss
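Note: a minimal call sketch for the updated rnnt_loss_smoothed (sizes and the blank id are assumptions for illustration):

    import torch
    import fast_rnnt

    B, T, S, C = 2, 10, 5, 20
    lm = torch.randn(B, S + 1, C)  # language-model logits, [B][S+1][C]
    am = torch.randn(B, T, C)      # acoustic-model logits, [B][T][C]
    symbols = torch.randint(1, C, (B, S))

    loss = fast_rnnt.rnnt_loss_smoothed(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=0,   # assumed blank id
        lm_only_scale=0.1,
        am_only_scale=0.1,
        boundary=None,
        rnnt_type="regular",    # new argument replacing modified=False
        delay_penalty=0.0,      # new argument
        reduction="mean",
    )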
fast_rnnt/python/tests/rnnt_loss_test.py
View file @ 2c2dc4b9
...
@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
      assert px.shape == (B, S, T + 1)
      assert py.shape == (B, S + 1, T)
      assert symbols.shape == (B, S)
      m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
      if device == torch.device("cpu"):
          expected = -m
...
@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
      boundary_[:, 2] = seq_length
      boundary_[:, 3] = frames
-     for modified in [True, False]:
+     for rnnt_type in ["regular", "modified", "constrained"]:
          for device in self.devices:
              # lm: [B][S+1][C]
              lm = lm_.to(device)
...
@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=termination_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
              )
-             assert px.shape == (B, S, T) if modified else (B, S, T + 1)
+             assert (
+                 px.shape == (B, S, T) if rnnt_type != "regular" else (B, S, T + 1)
+             )
              assert py.shape == (B, S + 1, T)
              assert symbols.shape == (B, S)
              m = fast_rnnt.mutual_information_recursion(
...
@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=termination_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
              )
              assert torch.allclose(m, expected.to(device))
...
@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
                  lm_only_scale=0.0,
                  am_only_scale=0.0,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
              )
              assert torch.allclose(m, expected.to(device))
...
@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=termination_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
              )
              assert torch.allclose(m, expected.to(device))
              # compare with torchaudio rnnt_loss
-             if self.has_torch_rnnt_loss and not modified:
+             if self.has_torch_rnnt_loss and rnnt_type == "regular":
                  import torchaudio.functional
                  m = torchaudio.functional.rnnt_loss(
...
@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=termination_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
              )
              assert torch.allclose(m, expected.to(device))
...
@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=termination_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
              )
              assert torch.allclose(m, expected.to(device))
...
@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
                  lm_only_scale=0.0,
                  am_only_scale=0.0,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
              )
              assert torch.allclose(m, expected.to(device))
...
@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
              torch_grad = torch.autograd.grad(torch_loss, logits2)
              torch_grad = torch_grad[0]
              assert torch.allclose(fast_loss, torch_loss, atol=1e-2, rtol=1e-2)
              assert torch.allclose(fast_grad, torch_grad, atol=1e-2, rtol=1e-2)

      def test_rnnt_loss_smoothed(self):
          B = 1
...
@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
      boundary_[:, 2] = seq_length
      boundary_[:, 3] = frames
-     for modified in [True, False]:
+     for rnnt_type in ["regular", "modified", "constrained"]:
          for device in self.devices:
              # normal rnnt
              am = am_.to(device)
...
@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=terminal_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
              )
-             print(f"Unpruned rnnt loss with modified {modified} : {fast_loss}")
+             print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {fast_loss}")
              # pruning
              simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
...
@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=terminal_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
                  return_grad=True,
                  reduction="none",
              )
...
@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
                  s_range=r,
              )
              # (B, T, r, C)
              pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
              logits = pruned_am + pruned_lm
...
@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
                  ranges=ranges,
                  termination_symbol=terminal_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
                  reduction="none",
              )
              print(f"Pruning loss with range {r} : {pruned_loss}")
      # Test sequences that have only a small number of symbols; in this case
      # s_range would be greater than S, which raised errors (like nan or inf
      # loss) in our previous versions.
...
@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
      print(f"B = {B}, T = {T}, S = {S}, C = {C}")
-     for modified in [True, False]:
+     for rnnt_type in ["regular", "modified", "constrained"]:
          for device in self.devices:
              # normal rnnt
              am = am_.to(device)
...
@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=terminal_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
                  reduction="none",
              )
-             print(f"Unpruned rnnt loss with modified {modified} : {loss}")
+             print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
              # pruning
              simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
...
@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
                  symbols=symbols,
                  termination_symbol=terminal_symbol,
                  boundary=boundary,
-                 modified=modified,
+                 rnnt_type=rnnt_type,
                  return_grad=True,
                  reduction="none",
              )
              S0 = 2
-             if modified:
+             if rnnt_type != "regular":
                  S0 = 1
              for r in range(S0, S + 2):
...
@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
                      ranges=ranges,
                      termination_symbol=terminal_symbol,
                      boundary=boundary,
-                     modified=modified,
+                     rnnt_type=rnnt_type,
                      reduction="none",
                  )
                  print(f"Pruned loss with range {r} : {pruned_loss}")

  if __name__ == "__main__":
      unittest.main()
setup.py
View file @ 2c2dc4b9
...
@@ -36,11 +36,10 @@ class BuildExtension(build_ext):
          system_make_args = os.environ.get("MAKEFLAGS", "")
          if cmake_args == "":
-             cmake_args = "-DCMAKE_BUILD_TYPE=Release"
+             cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF"
          if make_args == "" and system_make_args == "":
-             print("For fast compilation, run:")
-             print('export FT_MAKE_ARGS="-j"; python setup.py install')
+             make_args = ' -j '
          if "PYTHON_EXECUTABLE" not in cmake_args:
              print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
...