Unverified commit 2c2dc4b9, authored by Daniel Povey, committed by GitHub

Merge pull request #16 from pkufool/contiguous

Update rnnt_loss

parents c268c3d5 03b82cc4
@@ -46,7 +46,7 @@ message(STATUS "Enabled languages: ${languages}")

 project(fast_rnnt ${languages})

-set(FT_VERSION "1.0")
+set(FT_VERSION "1.2")

 set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
 set(DEFAULT_BUILD_TYPE "Release")

@@ -133,10 +133,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)

 include(pybind11)
 include(torch)

-if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
-  # CUB is included in CUDA toolkit 11.0 and above
-  include(cub)
-endif()

 if(FT_BUILD_TESTS)
   enable_testing()
...
# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

function(download_cub)
  if(CMAKE_VERSION VERSION_LESS 3.11)
    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
  endif()

  include(FetchContent)

  set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
  set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")

  FetchContent_Declare(cub
    URL ${cub_URL}
    URL_HASH ${cub_HASH}
  )

  FetchContent_GetProperties(cub)
  if(NOT cub_POPULATED)
    message(STATUS "Downloading cub")
    FetchContent_Populate(cub)
  endif()
  message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}")

  add_library(cub INTERFACE)
  target_include_directories(cub INTERFACE ${cub_SOURCE_DIR})
endfunction()

download_cub()
...
@@ -5,7 +5,6 @@ include(transform)

 set(srcs
   mutual_information_cpu.cu
-  utils.cu
 )

 if(NOT FT_WITH_CUDA)

@@ -18,10 +17,3 @@ add_library(mutual_information_core ${srcs})

 target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
 # for <torch/extension.h>
 target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
-# see https://github.com/NVIDIA/thrust/issues/1401
-# and https://github.com/k2-fsa/k2/pull/917
-target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
-target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
-if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
-  target_link_libraries(mutual_information_core PUBLIC cub)
-endif()
...
/**
 * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fast_rnnt/csrc/utils.h"

namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src) {
  TORCH_CHECK(src.dim() == 1, "Only support one dimension tensor");
  TORCH_CHECK(src.scalar_type() == torch::kLong, "Only support LongTensor");
  TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
  int32_t dim = src.numel();
  if (src.device().type() == torch::kCPU) {
    int64_t min_value = std::numeric_limits<int64_t>::max();
    int64_t *src_data = src.data_ptr<int64_t>();
    for (int32_t i = dim - 1; i >= 0; --i) {
      min_value = std::min(src_data[i], min_value);
      src_data[i] = min_value;
    }
  } else {
#ifdef FT_WITH_CUDA
    TORCH_CHECK(src.device().is_cuda());
    internal::MinOp<int64_t> min_op;
    auto src_data = src.data_ptr<int64_t>();
    internal::ConstReversedPtr<int64_t> src_ptr =
        internal::ConstReversedPtr<int64_t>(src_data, dim);
    internal::ReversedPtr<int64_t> dest_ptr =
        internal::ReversedPtr<int64_t>(src_data, dim);
    // The first call only determines the temporary device storage
    // requirements.
    std::size_t temp_storage_bytes = 0;
    auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
                                            src_ptr, dest_ptr, min_op, dim);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
    auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
                               torch::dtype(torch::kInt8).device(src.device()));
    int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
    s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
                                       src_ptr, dest_ptr, min_op, dim);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
    TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif // FT_WITH_CUDA
  }
}
} // namespace fast_rnnt
/**
 * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_

#include "torch/extension.h"

#ifdef FT_WITH_CUDA
#include "cub/cub.cuh"  // NOLINT

namespace fast_rnnt {
namespace internal {

template <typename T> struct MinOp {
  __host__ __device__ __forceinline__ T operator()(const T &a,
                                                   const T &b) const {
    return (a > b) ? b : a;
  }
};

// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
  const T *data;

  // data points to the last element now
  explicit ConstReversedPtr(const T *data, int32_t size)
      : data(data + size - 1) {}

  // operator[], operator+, and operator* are required by
  // cub::DeviceScan::InclusiveScan
  __host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
    return data[-i];
  }
  __host__ __device__ __forceinline__ ConstReversedPtr
  operator+(int32_t n) const {
    ConstReversedPtr tmp(*this);
    tmp.data -= n;
    return tmp;
  }
  __host__ __device__ __forceinline__ const T &operator*() const {
    return *data;
  }
};

template <typename T> struct ReversedPtr {
  T *data;

  // data points to the last element now
  explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}

  // operator[], operator+, and operator* are required by
  // cub::DeviceScan::InclusiveScan
  __host__ __device__ __forceinline__ T &operator[](int32_t i) {
    return data[-i];
  }
  __host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
    ReversedPtr tmp(*this);
    tmp.data -= n;
    return tmp;
  }
  __host__ __device__ __forceinline__ T &operator*() { return *data; }
};

} // namespace internal
} // namespace fast_rnnt

namespace std {
// value_type is required by cub::DeviceScan::InclusiveScan
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
  typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
  typedef T value_type;
};
} // namespace std
#endif // FT_WITH_CUDA

namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
} // namespace fast_rnnt

#endif // FAST_RNNT_CSRC_UTILS_H_
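# Editor's illustration (not part of this commit): the ReversedPtr wrappers
# above make cub::DeviceScan::InclusiveScan with a min operator walk the array
# back-to-front, i.e. they compute an in-place suffix minimum ("monotonic
# lower bound"). A minimal PyTorch equivalent of the same computation:
import torch

def monotonic_lower_bound_ref(src: torch.Tensor) -> torch.Tensor:
    # reverse, take a running (cumulative) minimum, then reverse back
    rev_min = torch.cummin(torch.flip(src, dims=(0,)), dim=0)[0]
    return torch.flip(rev_min, dims=(0,))

src = torch.tensor([2, 1, 3, 6, 5, 8], dtype=torch.int64)
print(monotonic_lower_bound_ref(src))  # tensor([1, 1, 3, 5, 5, 8])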
...
@@ -6,7 +6,6 @@ include(transform)

 set(fast_rnnt_srcs
   fast_rnnt.cu
   mutual_information.cu
-  utils.cu
 )

 if(NOT FT_WITH_CUDA)

@@ -14,7 +13,6 @@ if(NOT FT_WITH_CUDA)

 endif()

 pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
 target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
...
@@ -20,11 +20,9 @@

 #include "fast_rnnt/python/csrc/fast_rnnt.h"
 #include "fast_rnnt/python/csrc/mutual_information.h"
-#include "fast_rnnt/python/csrc/utils.h"

 PYBIND11_MODULE(_fast_rnnt, m) {
   m.doc() = "Python wrapper for Fast Rnnt.";
   fast_rnnt::PybindMutualInformation(m);
-  fast_rnnt::PybindUtils(m);
 }
...
@@ -65,5 +65,13 @@ void PybindMutualInformation(py::module &m) {

       },
       py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
       py::arg("ans_grad"));
+
+  m.def("with_cuda", []() -> bool {
+#ifdef FT_WITH_CUDA
+    return true;
+#else
+    return false;
+#endif
+  });
 }
 } // namespace fast_rnnt
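# Editor's note (hypothetical usage, not part of this diff): with the binding
# above and the re-export kept in fast_rnnt/__init__.py, callers can check at
# runtime whether the extension was built with -DFT_WITH_CUDA=ON:
#
#   import fast_rnnt
#   print(fast_rnnt.with_cuda())  # True or False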
/**
 * @copyright
 * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
 *
 * @copyright
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"

namespace fast_rnnt {
void PybindUtils(py::module &m) {
  m.def(
      "monotonic_lower_bound_",
      [](torch::Tensor &src) -> void {
        DeviceGuard guard(src.device());
        if (src.dim() == 1) {
          MonotonicLowerBound(src);
        } else if (src.dim() == 2) {
          int32_t dim0 = src.sizes()[0];
          for (int32_t i = 0; i < dim0; ++i) {
            auto sub = src.index({i});
            MonotonicLowerBound(sub);
          }
        } else {
          TORCH_CHECK(false,
                      "Only support 1 dimension and 2 dimensions tensor");
        }
      },
      py::arg("src"));

  m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
    return true;
#else
    return false;
#endif
  });
}
} // namespace fast_rnnt
/**
* @brief python wrappers for utils.h
*
* @copyright
* Copyright 2022 Xiaomi Corp. (author: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_
#include "fast_rnnt/python/csrc/fast_rnnt.h"
namespace fast_rnnt {
void PybindUtils(py::module &m);
} // namespace fast_rnnt
#endif // FAST_RNNT_PYTHON_CSRC_UTILS_H_
...
-from _fast_rnnt import monotonic_lower_bound_
 from _fast_rnnt import with_cuda

 from .mutual_information import mutual_information_recursion
...
@@ -285,9 +285,10 @@ def mutual_information_recursion(

         for s_begin, t_begin, s_end, t_end in boundary.tolist():
             assert 0 <= s_begin <= s_end <= S
             assert 0 <= t_begin <= t_end <= T

-    # The following assertions are for efficiency
-    assert px.is_contiguous()
-    assert py.is_contiguous()
+    # The following statements are for efficiency
+    px, py = px.contiguous(), py.contiguous()

     pxy_grads = [None, None]
     scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
                                                       boundary, return_grad)
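# Editor's illustration (not part of this commit): .contiguous() is a no-op
# for tensors that are already contiguous and makes a contiguous copy
# otherwise, so non-contiguous (e.g. transposed) inputs now work here instead
# of tripping an assertion.
import torch

t = torch.randn(3, 4).t()  # a transposed, non-contiguous view
assert not t.is_contiguous()
assert t.contiguous().is_contiguous()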
@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(

             assert 0 <= s_begin <= s_end <= S
             assert 0 <= t_begin <= t_end <= T

+    # The following statements are for efficiency
     px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
-    # The following assertions are for efficiency

     assert px_tot.ndim == 3
     assert py_tot.ndim == 3
...
@@ -16,7 +16,6 @@

 import os
-import fast_rnnt
 import torch
 from torch import Tensor
 from typing import Optional, Tuple, Union
@@ -26,13 +25,13 @@ from .mutual_information import mutual_information_recursion

 def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     """
     Insert -inf's into `px` in appropriate places if `boundary` is not
-    None. If boundary == None and modified == False, px[:,:,-1] will
+    None. If boundary == None and rnnt_type == "regular", px[:,:,-1] will
     be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
     to be -infinity.

     Args:
       px: a Tensor of shape [B][S][T+1] (this function is only
-          called if modified == False, see other docs for `modified`)
+          called if rnnt_type == "regular", see other docs for `rnnt_type`)
          px is modified in-place and returned.
       boundary: None, or a Tensor of shape [B][4] containing
          [s_begin, t_begin, s_end, t_end]; we need only t_end.
@@ -49,8 +48,8 @@ def get_rnnt_logprobs(

     am: Tensor,
     symbols: Tensor,
     termination_symbol: int,
+    rnnt_type: str = "regular",
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
 ) -> Tuple[Tensor, Tensor]:
""" """
Reduces RNN-T problem (the simple case, where joiner network is just Reduces RNN-T problem (the simple case, where joiner network is just
...@@ -97,20 +96,32 @@ def get_rnnt_logprobs( ...@@ -97,20 +96,32 @@ def get_rnnt_logprobs(
[0, 0, S, T] [0, 0, S, T]
if boundary is not supplied. if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero. Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will rnnt_type:
also be consumed, so at most 1 symbol can appear per frame. Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt that taking you to the next frame only if
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that will take you to the next
frame either emitting a blank or a non-blank symbol.
`constrained`: A version likes the modified one that will go to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
Returns: Returns:
(px, py) (the names are quite arbitrary). (px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified. px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T] py: logprobs, of shape [B][S+1][T]
in the recursion:: in the recursion::
p[b,0,0] = 0.0 p[b,0,0] = 0.0
if !modified: if rnnt_type == "regular":
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1]) p[b,s,t-1] + py[b,s,t-1])
if modified: if rnnt_type != "regular":
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1], p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1]) p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences .. where p[b][s][t] is the "joint score" of the pair of subsequences
...@@ -121,21 +132,22 @@ def get_rnnt_logprobs( ...@@ -121,21 +132,22 @@ def get_rnnt_logprobs(
(s,t) by one in the t direction, (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol. i.e. of emitting the termination/next-frame symbol.
if !modified, px[:,:,T] equals -infinity, meaning on the if rnnt_type == "regular", px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols. "one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating This is simply a way of incorporating
the probability of the termination symbol on the last frame. the probability of the termination symbol on the last frame.
""" """
-    assert lm.ndim == 3
-    assert am.ndim == 3
-    assert lm.shape[0] == am.shape[0]
-    assert lm.shape[2] == am.shape[2]
+    assert lm.ndim == 3, lm.ndim
+    assert am.ndim == 3, am.ndim
+    assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
+    assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])
     (B, T, C) = am.shape
     S = lm.shape[1] - 1
-    assert symbols.shape == (B, S)
-    assert S >= 1
-    assert T >= S
+    assert symbols.shape == (B, S), symbols.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     # subtracting am_max and lm_max is to ensure the probs are in a good range
     # to do exp() without causing underflow or overflow.

@@ -162,7 +174,7 @@

         -1
     )  # [B][S][T]

-    if not modified:
+    if rnnt_type == "regular":
         px_am = torch.cat(
             (
                 px_am,

@@ -189,8 +201,10 @@

     py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
     py = py_am + py_lm - normalizers

-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]

     return (px, py)
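# Editor's sketch (not part of this commit): a naive, readable reference for
# the recursion documented above; the real computation is the optimized
# C++/CUDA mutual_information_recursion(). Note that "constrained" needs no
# branch of its own here: get_rnnt_logprobs() already folded the next-context
# blank logprob into px via `px += py[:, 1:, :]`, after which it follows the
# "modified" topology.
import torch

def joint_score_reference(px, py, rnnt_type="regular"):
    # px: (B, S, T+1) for "regular", (B, S, T) otherwise; py: (B, S+1, T)
    B, S1, T = py.shape
    neg_inf = float("-inf")
    p = torch.full((B, S1, T + 1), neg_inf, dtype=px.dtype)
    p[:, 0, 0] = 0.0
    for s in range(S1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            x = torch.full((B,), neg_inf, dtype=px.dtype)
            y = torch.full((B,), neg_inf, dtype=px.dtype)
            if rnnt_type == "regular":
                if s > 0:  # emit symbol s-1, stay on frame t
                    x = p[:, s - 1, t] + px[:, s - 1, t]
            elif s > 0 and t > 0:  # emitting a symbol also advances t
                x = p[:, s - 1, t - 1] + px[:, s - 1, t - 1]
            if t > 0:  # emit blank / termination symbol, advance one frame
                y = p[:, s, t - 1] + py[:, s, t - 1]
            p[:, s, t] = torch.logaddexp(x, y)
    return p[:, S1 - 1, T]  # total log-prob for the default boundary [0,0,S,T]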
@@ -201,7 +215,8 @@ def rnnt_loss_simple(

     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
+    delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
     return_grad: bool = False,
 ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:

@@ -226,8 +241,23 @@

           [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                    frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that goes to the next
+                       frame when you emit a non-blank symbol, but this is done
+                       by "forcing" you to take the blank transition from the
+                       *next* context on the *current* frame, e.g. if we emit
+                       c given "a b" context, we are forced to emit "blank"
+                       given "b c" context on the current frame.
+      delay_penalty: A constant used to penalize symbol delay; this may be
+         needed when training with time masking, to avoid the time masking
+         encouraging the network to delay symbols.
+         See https://github.com/k2-fsa/k2/issues/955 for more details.
       reduction:
         Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
         `none`: no reduction will be applied.
@@ -255,8 +285,24 @@

         symbols=symbols,
         termination_symbol=termination_symbol,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )

+    if delay_penalty > 0.0:
+        B, S, T0 = px.shape
+        T = T0 if rnnt_type != "regular" else T0 - 1
+        if boundary is None:
+            offset = torch.tensor(
+                (T - 1) / 2,
+                dtype=px.dtype,
+                device=px.device,
+            ).expand(B, 1, 1)
+        else:
+            offset = (boundary[:, 3] - 1) / 2
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
+        penalty = penalty * delay_penalty
+        px += penalty.to(px.dtype)
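# Editor's worked example (illustration only): with delay_penalty = 0.1,
# T = 5 real frames and no boundary, offset = (5 - 1) / 2 = 2.0, so the
# bonus added to px at frame t is 0.1 * (2 - t):
#   t:        0     1     2     3     4
#   penalty: +0.2  +0.1   0.0  -0.1  -0.2
# i.e. early emission is rewarded and late emission penalized, which
# counteracts the delay that time masking otherwise encourages. (For
# "regular" rnnt, px has a trailing t = T column, but it is -inf there,
# so the extra penalty entry is harmless.)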
     scores_and_grads = mutual_information_recursion(
         px=px, py=py, boundary=boundary, return_grad=return_grad
     )

@@ -268,9 +314,9 @@

     elif reduction == "sum":
         loss = -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
     return (loss, scores_and_grads[1]) if return_grad else loss
@@ -279,7 +325,7 @@ def get_rnnt_logprobs_joint(

     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """Reduces RNN-T problem to a compact, standard form that can then be given
     (with boundaries) to mutual_information_recursion().

@@ -300,21 +346,33 @@

           [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                    frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that goes to the next
+                       frame when you emit a non-blank symbol, but this is done
+                       by "forcing" you to take the blank transition from the
+                       *next* context on the *current* frame, e.g. if we emit
+                       c given "a b" context, we are forced to emit "blank"
+                       given "b c" context on the current frame.

     Returns:
         (px, py) (the names are quite arbitrary)::
-           px: logprobs, of shape [B][S][T+1]
+           px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
+               [B][S][T] if rnnt_type is not regular.
           py: logprobs, of shape [B][S+1][T]

     in the recursion::

          p[b,0,0] = 0.0
-         if !modified:
+         if rnnt_type == "regular":
            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                               p[b,s,t-1] + py[b,s,t-1])
-         if modified:
+         if rnnt_type != "regular":
            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                               p[b,s,t-1] + py[b,s,t-1])

     .. where p[b][s][t] is the "joint score" of the pair of subsequences of

@@ -324,17 +382,18 @@

     of extending the subsequences of length (s,t) by one in the t direction,
     i.e. of emitting the termination/next-frame symbol.

-    if !modified, px[:,:,T] equals -infinity, meaning on the
+    if rnnt_type == "regular", px[:,:,T] equals -infinity, meaning on the
     "one-past-the-last" frame we cannot emit any symbols.
     This is simply a way of incorporating
     the probability of the termination symbol on the last frame.
     """
-    assert logits.ndim == 4
+    assert logits.ndim == 4, logits.ndim
     (B, T, S1, C) = logits.shape
     S = S1 - 1
-    assert symbols.shape == (B, S)
-    assert S >= 1
-    assert T >= S
+    assert symbols.shape == (B, S), symbols.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))

@@ -344,7 +403,7 @@

     ).squeeze(-1)
     px = px.permute((0, 2, 1))

-    if not modified:
+    if rnnt_type == "regular":
         px = torch.cat(
             (
                 px,

@@ -361,11 +420,11 @@

         logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
     )  # [B][S+1][T]
     py -= normalizers

-    px = px.contiguous()
-    py = py.contiguous()

-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]

     return (px, py)
@@ -375,7 +434,8 @@ def rnnt_loss(

     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
+    delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
 ) -> Tensor:
     """A normal RNN-T loss, which uses a 'joiner' network output as input,

@@ -395,8 +455,23 @@

         [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                    frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that goes to the next
+                       frame when you emit a non-blank symbol, but this is done
+                       by "forcing" you to take the blank transition from the
+                       *next* context on the *current* frame, e.g. if we emit
+                       c given "a b" context, we are forced to emit "blank"
+                       given "b c" context on the current frame.
+      delay_penalty: A constant used to penalize symbol delay; this may be
+         needed when training with time masking, to avoid the time masking
+         encouraging the network to delay symbols.
+         See https://github.com/k2-fsa/k2/issues/955 for more details.
       reduction:
         Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
         `none`: no reduction will be applied.

@@ -414,8 +489,24 @@

         symbols=symbols,
         termination_symbol=termination_symbol,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )
+    if delay_penalty > 0.0:
+        B, S, T0 = px.shape
+        T = T0 if rnnt_type != "regular" else T0 - 1
+        if boundary is None:
+            offset = torch.tensor(
+                (T - 1) / 2,
+                dtype=px.dtype,
+                device=px.device,
+            ).expand(B, 1, 1)
+        else:
+            offset = (boundary[:, 3] - 1) / 2
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
+        penalty = penalty * delay_penalty
+        px += penalty.to(px.dtype)

     negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)

     if reduction == "none":
         return -negated_loss

@@ -424,30 +515,63 @@

     elif reduction == "sum":
         return -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
+def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
+    """Compute a monotonically increasing lower bound of the tensor `x` on the
+    last dimension. The basic idea is: we traverse the tensor in reverse order
+    and update the current element with the following statement,
+
+        min_value = min(x[i], min_value)
+        x[i] = min_value
+
+    >>> import torch
+    >>> x = torch.tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
+    >>> _monotonic_lower_bound(x)
+    tensor([0, 1, 1, 3, 5, 5, 8], dtype=torch.int32)
+    >>> x
+    tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
+    >>> x = torch.randint(20, (3, 6), dtype=torch.int32)
+    >>> x
+    tensor([[12, 18,  5,  4, 18, 17],
+            [11, 14, 14,  3, 10,  4],
+            [19,  3,  8, 13,  7, 19]], dtype=torch.int32)
+    >>> _monotonic_lower_bound(x)
+    tensor([[ 4,  4,  4,  4, 17, 17],
+            [ 3,  3,  3,  3,  4,  4],
+            [ 3,  3,  7,  7,  7, 19]], dtype=torch.int32)
+
+    Args:
+      x:
+        The source tensor.
+
+    Returns:
+      Returns a tensor which is monotonic on the last dimension
+      (i.e. it satisfies `x[i] <= x[i+1]`).
+    """
+    x = torch.flip(x, dims=(-1,))
+    x, _ = torch.cummin(x, dim=-1)
+    x = torch.flip(x, dims=(-1,))
+    return x
 def _adjust_pruning_lower_bound(
     s_begin: torch.Tensor, s_range: int
 ) -> torch.Tensor:
-    """Adjust s_begin (pruning lower bound) to make it satisfied the following
-    constrains
+    """Adjust s_begin (pruning lower bounds) to make it satisfy the following
+    constraints

       - monotonic increasing, i.e. s_begin[i] <= s_begin[i + 1]
       - start with symbol 0 at first frame.
-      - s_begin[i + 1] - s_begin[i] < s_range, whicn means that we can't skip
+      - s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
        any symbols.

-    To make it monotonic increasing, we can use `monotonic_lower_bound` function
-    in k2, which guarantee `s_begin[i] <= s_begin[i + 1]`. The main idea is:
+    To make it monotonic increasing, we can use `_monotonic_lower_bound` above,
+    which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea is:
     traverse the array in reverse order and update the elements by
-    `min_value = min(a_begin[i], min_value)`, the initial `min_value` set to
-    `inf`.
+    `min_value = min(a_begin[i], min_value)`.

-    The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range`
-    constrain is a little tricky. We first transform `s_begin` with
+    The method we use to realize the `s_begin[i + 1] - s_begin[i] < s_range`
+    constraint is a little tricky. We first transform `s_begin` with
     `s_begin = -(s_begin - (s_range - 1) * torch.arange(0,T))`
     then we make the transformed `s_begin` monotonic increasing; after that,
     we transform back `s_begin` with the same formula as the previous

@@ -467,21 +591,22 @@

     """
     # s_begin (B, T)
     (B, T) = s_begin.shape
-    fast_rnnt.monotonic_lower_bound_(s_begin)
+    s_begin = _monotonic_lower_bound(s_begin)
     # do the magic transformation
     s_begin = -(
         s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
     )
     # make the transformed tensor to be non-decreasing
-    fast_rnnt.monotonic_lower_bound_(s_begin)
+    s_begin = _monotonic_lower_bound(s_begin)
     # make start symbol to be zero.
-    s_begin = torch.where(s_begin < 0, 0, s_begin)
+    s_begin = torch.clamp(s_begin, min=0)
     # do the magic transformation again to recover s_begin
     s_begin = -(
         s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
     )
     return s_begin
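# Editor's worked example (illustration only, hypothetical values) of the
# "magic transformation" above: a jump of 3 symbols at frame 2 gets smoothed
# out so that s_begin[i + 1] - s_begin[i] < s_range holds everywhere.
import torch

s_begin = torch.tensor([[0, 0, 3, 3, 4]])  # (B=1, T=5)
s_range = 2
T = s_begin.shape[1]
t = torch.arange(T)
y = -(s_begin - (s_range - 1) * t)         # here: y = t - s_begin
y = _monotonic_lower_bound(y)              # suffix-min => non-decreasing
y = torch.clamp(y, min=0)                  # forces s_begin[0] == 0
s_begin = -(y - (s_range - 1) * t)
print(s_begin)                             # tensor([[0, 1, 2, 3, 4]])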
 # To get more insight into how we calculate pruning bounds, please read
 # chapter 3.2 (Pruning bounds) of our Pruned RNN-T paper
 # (https://arxiv.org/pdf/2206.13236.pdf)

@@ -512,9 +637,9 @@ def get_rnnt_prune_ranges(

     Note:
-      For the generated tensor ranges(assuming batch size is 1), ranges[:, 0]
-      is a monotonic increasing tensor from 0 to `len(symbols)` and it satisfies
-      `ranges[t+1, 0] - ranges[t, 0] < s_range` which means we won't skip any
-      symbols.
+      For the generated tensor ranges (assuming batch size is 1), ranges[:, 0]
+      is a monotonic increasing tensor from 0 to `len(symbols) - s_range` and
+      it satisfies `ranges[t+1, 0] - ranges[t, 0] < s_range`, which means we
+      won't skip any symbols.

     Args:
       px_grad:

@@ -529,21 +654,21 @@

       s_range:
         How many symbols to keep for each frame.
     Returns:
-      A tensor contains the kept symbols indexes for each frame, with shape
-      (B, T, s_range).
+      A tensor of shape (B, T, s_range) containing the indexes of the
+      kept symbols for each frame.
     """
     (B, S, T1) = px_grad.shape
     T = py_grad.shape[-1]
-    assert T1 in [T, T + 1]
+    assert T1 in [T, T + 1], T1
     S1 = S + 1
-    assert py_grad.shape == (B, S + 1, T)
-    assert boundary.shape == (B, 4)
-    assert S >= 1
-    assert T >= S
+    assert py_grad.shape == (B, S + 1, T), py_grad.shape
+    assert boundary.shape == (B, 4), boundary.shape
+    assert S >= 1, S
+    assert T >= S, (T, S)

     # s_range > S means we won't prune out any symbols. To make indexing with
-    # ranges runs normally, s_range should be equal to or less than ``S + 1``.
+    # ranges run normally, s_range should be equal to or less than ``S + 1``.
     if s_range > S:
         s_range = S + 1

@@ -591,16 +716,17 @@

     mask = mask < boundary[:, 3].reshape(B, 1) - 1

     s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
-    # handle the cases when `len(symbols) < s_range`
+    # handle the cases where `len(symbols) < s_range`
     s_begin_padding = torch.clamp(s_begin_padding, min=0)

     s_begin = torch.where(mask, s_begin, s_begin_padding)

-    # adjusting lower bound to make it satisfied some constrains, see docs in
-    # `_adjust_pruning_lower_bound` for more details of these constrains.
-    # T1 == T here means we are using the modified version of transducer,
-    # the third constrain becomes `s_begin[i + 1] - s_begin[i] < 2`, because
-    # it only emits one symbol per frame.
+    # adjust the lower bound to make it satisfy some constraints; see docs in
+    # `_adjust_pruning_lower_bound` for more details of these constraints.
+    # T1 == T here means we are using the non-regular (i.e. modified rnnt or
+    # constrained rnnt) version of transducer; the third constraint becomes
+    # `s_begin[i + 1] - s_begin[i] < 2`, because it only emits one symbol per
+    # frame.
     s_begin = _adjust_pruning_lower_bound(s_begin, 2 if T1 == T else s_range)

     ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
@@ -613,8 +739,8 @@

 def do_rnnt_pruning(
     am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Prune the output of encoder(am) output and prediction network(lm)
-    output of RNNT.
+    """Prune the output of the encoder (am) and the prediction network (lm)
+    with ranges generated by `get_rnnt_prune_ranges`.

     Args:
       am:

@@ -632,9 +758,9 @@

     # am (B, T, C)
     # lm (B, S + 1, C)
     # ranges (B, T, s_range)
-    assert ranges.shape[0] == am.shape[0]
-    assert ranges.shape[0] == lm.shape[0]
-    assert am.shape[1] == ranges.shape[1]
+    assert ranges.shape[0] == am.shape[0], (ranges.shape[0], am.shape[0])
+    assert ranges.shape[0] == lm.shape[0], (ranges.shape[0], lm.shape[0])
+    assert am.shape[1] == ranges.shape[1], (am.shape[1], ranges.shape[1])
     (B, T, s_range) = ranges.shape
     (B, S1, C) = lm.shape
     S = S1 - 1
@@ -672,9 +798,9 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):

                  [ 8,  9,  5,  6,  7],
                  [12, 13, 14, 10, 11]]])
     """
-    assert src.dim() == 3
+    assert src.dim() == 3, src.dim()
     (B, T, S) = src.shape
-    assert shifts.shape == (B, T)
+    assert shifts.shape == (B, T), shifts.shape

     index = (
         torch.arange(S, device=src.device)
@@ -692,7 +818,7 @@ def get_rnnt_logprobs_pruned(

     ranges: Tensor,
     termination_symbol: int,
     boundary: Tensor,
-    modified: bool = False,
+    rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
"""Construct px, py for mutual_information_recursion with pruned output. """Construct px, py for mutual_information_recursion with pruned output.
...@@ -704,6 +830,12 @@ def get_rnnt_logprobs_pruned( ...@@ -704,6 +830,12 @@ def get_rnnt_logprobs_pruned(
{0..C-1}. {0..C-1}.
ranges: ranges:
A tensor containing the symbol ids for each frame that we want to keep. A tensor containing the symbol ids for each frame that we want to keep.
It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
``logits[b,t,:,:]`` represents the logits with positions
``s, s + 1, ... s + s_range - 1``.
See docs in :func:`get_rnnt_prune_ranges` for more details of what
ranges contains.
termination_symbol: termination_symbol:
the termination symbol, with 0 <= termination_symbol < C the termination symbol, with 0 <= termination_symbol < C
boundary: boundary:
...@@ -712,21 +844,53 @@ def get_rnnt_logprobs_pruned( ...@@ -712,21 +844,53 @@ def get_rnnt_logprobs_pruned(
[0, 0, S, T] [0, 0, S, T]
if boundary is not supplied. if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero. Most likely you will want begin_symbol and begin_frame to be zero.
modified: if True, each time a real symbol is consumed a frame will rnnt_type:
also be consumed, so at most 1 symbol can appear per frame. Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt that taking you to the next frame only if
emitting a blank (i.e., emitting a symbol does not take you
to the next frame).
`modified`: A modified version of rnnt that will take you to the next
frame either emitting a blank or a non-blank symbol.
`constrained`: A version likes the modified one that will go to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
Returns: Returns:
Return the px (B, S, T) if modified else (B, S, T + 1) and (px, py) (the names are quite arbitrary)::
py (B, S + 1, T) needed by mutual_information_recursion. px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
[B][S][T] if rnnt_type is not regular.
py: logprobs, of shape [B][S+1][T]
in the recursion::
p[b,0,0] = 0.0
if rnnt_type == "regular":
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if rnnt_type != "regular":
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
if `rnnt_type == "regular"`, px[:,:,T] equals -infinity, meaning on the
"one-past-the-last" frame we cannot emit any symbols.
This is simply a way of incorporating
the probability of the termination symbol on the last frame.
""" """
     # logits (B, T, s_range, C)
     # symbols (B, S)
     # ranges (B, T, s_range)
-    assert logits.ndim == 4
+    assert logits.ndim == 4, logits.ndim
     (B, T, s_range, C) = logits.shape
-    assert ranges.shape == (B, T, s_range)
+    assert ranges.shape == (B, T, s_range), ranges.shape
     (B, S) = symbols.shape
-    assert S >= 1
-    assert T >= S
+    assert S >= 1, S
+    assert T >= S, (T, S)
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type

     normalizers = torch.logsumexp(logits, dim=3)

@@ -774,7 +938,7 @@

     px = px.permute((0, 2, 1))

-    if not modified:
+    if rnnt_type == "regular":
         px = torch.cat(
             (
                 px,

@@ -807,11 +971,10 @@

     # (B, S + 1, T)
     py = py.permute((0, 2, 1))

-    px = px.contiguous()
-    py = py.contiguous()

-    if not modified:
+    if rnnt_type == "regular":
         px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]

     return (px, py)
@@ -822,12 +985,13 @@ def rnnt_loss_pruned(

     ranges: Tensor,
     termination_symbol: int,
     boundary: Tensor = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
+    delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
 ) -> Tensor:
-    """A RNN-T loss with pruning, which uses a pruned 'joiner' network output
-    as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
-    s_range means the symbols number kept for each frame.
+    """An RNN-T loss with pruning, which uses the output of a pruned 'joiner'
+    network as input, i.e. a 4-dimensional tensor of shape (B, T, s_range, C),
+    where s_range is the number of symbols kept for each frame.

     Args:
       logits:

@@ -838,6 +1002,12 @@

         of the sequence.
       ranges:
         A tensor containing the symbol ids for each frame that we want to keep.
+        It is a LongTensor of shape ``[B][T][s_range]``, where ``ranges[b,t,0]``
+        contains the begin symbol ``0 <= s <= S - s_range + 1``, such that
+        ``logits[b,t,:,:]`` represents the logits with positions
+        ``s, s + 1, ... s + s_range - 1``.
+        See docs in :func:`get_rnnt_prune_ranges` for more details of what
+        ranges contains.
       termination_symbol:
         The identity of the termination symbol, must be in {0..C-1}
       boundary:

@@ -845,8 +1015,23 @@

         [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
-      modified: if True, each time a real symbol is consumed a frame will
-          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                    frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that goes to the next
+                       frame when you emit a non-blank symbol, but this is done
+                       by "forcing" you to take the blank transition from the
+                       *next* context on the *current* frame, e.g. if we emit
+                       c given "a b" context, we are forced to emit "blank"
+                       given "b c" context on the current frame.
+      delay_penalty: A constant used to penalize symbol delay; this may be
+         needed when training with time masking, to avoid the time masking
+         encouraging the network to delay symbols.
+         See https://github.com/k2-fsa/k2/issues/955 for more details.
       reduction:
         Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
         `none`: no reduction will be applied.

@@ -854,8 +1039,8 @@

         `sum`: the output will be summed.
         Default: `mean`
     Returns:
-      If recursion is `none`, returns a tensor of shape (B,), containing the
-      total RNN-T loss values for each element of the batch, otherwise a scalar
+      If reduction is `none`, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each sequence of the batch, otherwise a scalar
      with the reduction applied.
     """
     px, py = get_rnnt_logprobs_pruned(

@@ -864,8 +1049,24 @@

         ranges=ranges,
         termination_symbol=termination_symbol,
         boundary=boundary,
-        modified=modified,
+        rnnt_type=rnnt_type,
     )

+    if delay_penalty > 0.0:
+        B, S, T0 = px.shape
+        T = T0 if rnnt_type != "regular" else T0 - 1
+        if boundary is None:
+            offset = torch.tensor(
+                (T - 1) / 2,
+                dtype=px.dtype,
+                device=px.device,
+            ).expand(B, 1, 1)
+        else:
+            offset = (boundary[:, 3] - 1) / 2
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
+        penalty = penalty * delay_penalty
+        px += penalty.to(px.dtype)
     negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)

     if reduction == "none":
         return -negated_loss

@@ -874,9 +1075,9 @@

     elif reduction == "sum":
         return -torch.sum(negated_loss)
     else:
-        assert (
-            False
-        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        raise ValueError(
+            f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+        )
 def get_rnnt_logprobs_smoothed(

@@ -887,7 +1088,7 @@

     lm_only_scale: float = 0.1,
     am_only_scale: float = 0.1,
     boundary: Optional[Tensor] = None,
-    modified: bool = False,
+    rnnt_type: str = "regular",
 ) -> Tuple[Tensor, Tensor]:
     """
     Reduces RNN-T problem (the simple case, where joiner network is just

@@ -950,18 +1151,32 @@

         Most likely you will want begin_symbol and begin_frame to be zero.
       modified: if True, each time a real symbol is consumed a frame will
          also be consumed, so at most 1 symbol can appear per frame.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that takes you to the next
+                    frame whether you emit a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that goes to the next
+                       frame when you emit a non-blank symbol, but this is done
+                       by "forcing" you to take the blank transition from the
+                       *next* context on the *current* frame, e.g. if we emit
+                       c given "a b" context, we are forced to emit "blank"
+                       given "b c" context on the current frame.

     Returns:
         (px, py) (the names are quite arbitrary).
-           px: logprobs, of shape [B][S][T+1] if !modified, [B][S][T] if modified.
+           px: logprobs, of shape [B][S][T+1] if rnnt_type == "regular",
+               [B][S][T] if rnnt_type != "regular".
           py: logprobs, of shape [B][S+1][T]

     in the recursion::

          p[b,0,0] = 0.0
-         if !modified:
+         if rnnt_type == "regular":
            p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                               p[b,s,t-1] + py[b,s,t-1])
-         if modified:
+         if rnnt_type != "regular":
            p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                               p[b,s,t-1] + py[b,s,t-1])

     .. where p[b][s][t] is the "joint score" of the pair of subsequences

@@ -976,15 +1191,16 @@

     we cannot emit any symbols. This is simply a way of incorporating
     the probability of the termination symbol on the last frame.
""" """
assert lm.ndim == 3 assert lm.ndim == 3, lm.ndim
assert am.ndim == 3 assert am.ndim == 3, am.ndim
assert lm.shape[0] == am.shape[0] assert lm.shape[0] == am.shape[0], (lm.shape[0], am.shape[0])
assert lm.shape[2] == am.shape[2] assert lm.shape[2] == am.shape[2], (lm.shape[2], am.shape[2])
(B, T, C) = am.shape (B, T, C) = am.shape
S = lm.shape[1] - 1 S = lm.shape[1] - 1
assert symbols.shape == (B, S) assert symbols.shape == (B, S), symbols.shape
assert S >= 1 assert S >= 1, S
assert T >= S assert T >= S, (T, S)
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
# Caution: some parts of this code are a little less clear than they could # Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that # be due to optimizations. In particular it may not be totally obvious that
...@@ -1036,7 +1252,7 @@ def get_rnnt_logprobs_smoothed( ...@@ -1036,7 +1252,7 @@ def get_rnnt_logprobs_smoothed(
-1 -1
) # [B][S][T] ) # [B][S][T]
if not modified: if rnnt_type == "regular":
px_am = torch.cat( px_am = torch.cat(
( (
px_am, px_am,
...@@ -1095,8 +1311,10 @@ def get_rnnt_logprobs_smoothed( ...@@ -1095,8 +1311,10 @@ def get_rnnt_logprobs_smoothed(
+ py_amonly * am_only_scale + py_amonly * am_only_scale
) )
if not modified: if rnnt_type == "regular":
px_interp = fix_for_boundary(px_interp, boundary) px_interp = fix_for_boundary(px_interp, boundary)
elif rnnt_type == "constrained":
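# Emitting symbol s on frame t is "forced" to also take the blank
# transition from the *next* context on the same frame, so fold the
# blank logprobs of contexts 1..S (py_interp[:, 1:, :]) into px_interp.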
px_interp += py_interp[:, 1:, :]
return (px_interp, py_interp)
...@@ -1109,7 +1327,8 @@ def rnnt_loss_smoothed(
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Optional[Tensor] = None,
rnnt_type: str = "regular",
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
return_grad: bool = False,
) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]:
...@@ -1141,8 +1360,23 @@ def rnnt_loss_smoothed(
[0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
rnnt_type:
Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
`regular`: The regular rnnt, which takes you to the next frame only when
emitting a blank (i.e., emitting a non-blank symbol does not
take you to the next frame).
`modified`: A modified version of rnnt that takes you to the next
frame whether you emit a blank or a non-blank symbol.
`constrained`: A version like the modified one that goes to the next
frame when you emit a non-blank symbol, but this is done
by "forcing" you to take the blank transition from the
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
delay_penalty: A constant used to penalize symbol delay; this may be
needed when training with time masking, to prevent the time masking
from encouraging the network to delay symbols.
See https://github.com/k2-fsa/k2/issues/955 for more details.
reduction:
Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
`none`: no reduction will be applied.
...@@ -1173,8 +1407,24 @@ def rnnt_loss_smoothed(
lm_only_scale=lm_only_scale,
am_only_scale=am_only_scale,
boundary=boundary,
rnnt_type=rnnt_type,
)
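# The block below applies the delay_penalty documented above: a symbol
# emitted on frame t gets an extra delay_penalty * ((T - 1) / 2 - t),
# rewarding early emissions and penalizing late ones; e.g. for T = 4
# the per-frame offsets are [1.5, 0.5, -0.5, -1.5].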
if delay_penalty > 0.0:
B, S, T0 = px.shape
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
penalty = offset.reshape(B, 1, 1) - torch.arange(
T0, device=px.device
).reshape(1, 1, T0)
penalty = penalty * delay_penalty
px += penalty.to(px.dtype)
scores_and_grads = mutual_information_recursion(
px=px, py=py, boundary=boundary, return_grad=return_grad
)
...@@ -1186,7 +1436,7 @@ def rnnt_loss_smoothed(
elif reduction == "sum":
loss = -torch.sum(negated_loss)
else:
raise ValueError(
f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
)
return (loss, scores_and_grads[1]) if return_grad else loss
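A minimal usage sketch of the smoothed loss with the new arguments; the shapes, scales, and the choice of rnnt_type="modified" are illustrative only:

import torch
import fast_rnnt

B, S, T, C = 2, 5, 20, 30
lm = torch.randn(B, S + 1, C)          # label-side logits
am = torch.randn(B, T, C)              # acoustic-side logits
symbols = torch.randint(1, C, (B, S))  # 0 serves as the termination symbol

loss = fast_rnnt.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    lm_only_scale=0.1,
    am_only_scale=0.1,
    rnnt_type="modified",
    delay_penalty=0.01,
    reduction="mean",
)
print(loss)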
...@@ -90,7 +90,9 @@ class TestRnntLoss(unittest.TestCase):
assert px.shape == (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(
px=px, py=py, boundary=None
)
if device == torch.device("cpu"):
expected = -m
...@@ -205,7 +207,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# lm: [B][S+1][C]
lm = lm_.to(device)
...@@ -220,9 +222,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
)
assert px.shape == (
(B, S, T) if rnnt_type != "regular" else (B, S, T + 1)
)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(
...@@ -239,7 +245,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
...@@ -251,7 +257,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
...@@ -261,12 +267,12 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss and rnnt_type == "regular":
import torchaudio.functional
m = torchaudio.functional.rnnt_loss(
...@@ -288,7 +294,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
...@@ -298,7 +304,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
...@@ -310,7 +316,7 @@ class TestRnntLoss(unittest.TestCase):
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
rnnt_type=rnnt_type,
)
assert torch.allclose(m, expected.to(device))
...@@ -368,9 +374,13 @@ class TestRnntLoss(unittest.TestCase):
torch_grad = torch.autograd.grad(torch_loss, logits2)
torch_grad = torch_grad[0]
assert torch.allclose(
fast_loss, torch_loss, atol=1e-2, rtol=1e-2
)
assert torch.allclose(
fast_grad, torch_grad, atol=1e-2, rtol=1e-2
)
def test_rnnt_loss_smoothed(self):
B = 1
...@@ -443,7 +453,7 @@ class TestRnntLoss(unittest.TestCase):
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
...@@ -460,12 +470,10 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
)
print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {fast_loss}")
# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
...@@ -474,7 +482,7 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)
...@@ -487,7 +495,9 @@ class TestRnntLoss(unittest.TestCase):
s_range=r,
)
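# ranges: (B, T, r), the symbol indices kept on each frame after pruning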
# (B, T, r, C)
pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning(
am=am, lm=lm, ranges=ranges
)
logits = pruned_am + pruned_lm
...@@ -500,12 +510,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Pruning loss with range {r} : {pruned_loss}")
# Test sequences that have only a small number of symbols;
# in this case s_range would be greater than S, which used to
# raise errors (like nan or inf loss) in our previous versions.
...@@ -531,7 +540,7 @@ class TestRnntLoss(unittest.TestCase):
print(f"B = {B}, T = {T}, S = {S}, C = {C}")
for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
...@@ -550,13 +559,11 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}")
# pruning
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
...@@ -565,13 +572,13 @@ class TestRnntLoss(unittest.TestCase):
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)
S0 = 2
if rnnt_type != "regular":
S0 = 1
for r in range(S0, S + 2):
...@@ -597,10 +604,11 @@ class TestRnntLoss(unittest.TestCase):
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
reduction="none",
)
print(f"Pruned loss with range {r} : {pruned_loss}")
if __name__ == "__main__":
unittest.main()
...@@ -36,11 +36,10 @@ class BuildExtension(build_ext):
system_make_args = os.environ.get("MAKEFLAGS", "")
if cmake_args == "":
cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF"
if make_args == "" and system_make_args == "":
make_args = ' -j '
if "PYTHON_EXECUTABLE" not in cmake_args:
print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
...