OpenDAS / FAST-RNNT · Commits

Commit b32f8a26, authored Oct 30, 2022 by pkufool (parent: b0ed23ef)

Replace monotonic_lower_bound with an efficient implementation; remove the cub dependency
Showing 13 changed files with 49 additions and 320 deletions (+49, -320).
CMakeLists.txt                                +0   -4
cmake/cub.cmake                               +0   -42
fast_rnnt/csrc/CMakeLists.txt                 +0   -8
fast_rnnt/csrc/utils.cu                       +0   -63
fast_rnnt/csrc/utils.h                        +0   -101
fast_rnnt/python/csrc/CMakeLists.txt          +0   -2
fast_rnnt/python/csrc/fast_rnnt.cu            +0   -2
fast_rnnt/python/csrc/mutual_information.cu   +8   -0
fast_rnnt/python/csrc/utils.cu                +0   -56
fast_rnnt/python/csrc/utils.h                 +0   -32
fast_rnnt/python/fast_rnnt/__init__.py        +0   -1
fast_rnnt/python/fast_rnnt/rnnt_loss.py       +39  -7
fast_rnnt/python/tests/rnnt_loss_test.py      +2   -2
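In short: the custom C++/CUDA MonotonicLowerBound operator (whose GPU path needed cub's device-wide inclusive scan, and therefore a downloaded cub for CUDA < 11.0) is deleted, and the same operation is re-implemented in pure PyTorch inside rnnt_loss.py using torch.flip and torch.cummin. A minimal sketch of the equivalence, checked against a plain reference loop; the helper names here are illustrative, not the project's API, and the example vector comes from the new function's docstring:

import torch

def monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
    # Suffix minimum along the last dim: reverse, running minimum, reverse back.
    x = torch.flip(x, dims=(-1,))
    x, _ = torch.cummin(x, dim=-1)
    return torch.flip(x, dims=(-1,))

def reference(xs):
    # The semantics the removed C++/CUDA operator implemented: traverse in
    # reverse order, keeping a running minimum.
    out, min_value = list(xs), float("inf")
    for i in range(len(xs) - 1, -1, -1):
        min_value = min(xs[i], min_value)
        out[i] = min_value
    return out

x = torch.tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
assert monotonic_lower_bound(x).tolist() == reference(x.tolist())  # [0, 1, 1, 3, 5, 5, 8]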
CMakeLists.txt

@@ -133,10 +133,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 include(pybind11)
 include(torch)
-if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
-  # CUB is included in CUDA toolkit 11.0 and above
-  include(cub)
-endif()
 if(FT_BUILD_TESTS)
   enable_testing()
cmake/cub.cmake (deleted, file mode 100644 → 0)

# Copyright 2020 Fangjun Kuang (csukuangfj@gmail.com)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

function(download_cub)
  if(CMAKE_VERSION VERSION_LESS 3.11)
    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
  endif()

  include(FetchContent)

  set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
  set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")

  FetchContent_Declare(cub
    URL ${cub_URL}
    URL_HASH ${cub_HASH}
  )

  FetchContent_GetProperties(cub)
  if(NOT cub)
    message(STATUS "Downloading cub")
    FetchContent_Populate(cub)
  endif()
  message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}")
  add_library(cub INTERFACE)
  target_include_directories(cub INTERFACE ${cub_SOURCE_DIR})
endfunction()

download_cub()
fast_rnnt/csrc/CMakeLists.txt

@@ -5,7 +5,6 @@ include(transform)
 set(srcs
   mutual_information_cpu.cu
-  utils.cu
 )
 if(NOT FT_WITH_CUDA)

@@ -18,10 +17,3 @@ add_library(mutual_information_core ${srcs})
 target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
 # for <torch/extension.h>
 target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS})
-# see https://github.com/NVIDIA/thrust/issues/1401
-# and https://github.com/k2-fsa/k2/pull/917
-target_compile_definitions(mutual_information_core PUBLIC CUB_WRAPPED_NAMESPACE=fast_rnnt)
-target_compile_definitions(mutual_information_core PUBLIC THRUST_NS_QUALIFIER=thrust)
-if(FT_WITH_CUDA AND CUDA_VERSION VERSION_LESS 11.0)
-  target_link_libraries(mutual_information_core PUBLIC cub)
-endif()
fast_rnnt/csrc/utils.cu (deleted, file mode 100644 → 0)

/**
 * Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fast_rnnt/csrc/utils.h"

namespace fast_rnnt {

void MonotonicLowerBound(torch::Tensor &src) {
  TORCH_CHECK(src.dim() == 1, "Only support one dimension tensor");
  TORCH_CHECK(src.scalar_type() == torch::kLong, "Only support LongTensor");
  TORCH_CHECK(src.is_contiguous(), "Expected to be contiguous");
  int32_t dim = src.numel();
  if (src.device().type() == torch::kCPU) {
    int64_t min_value = std::numeric_limits<int64_t>::max();
    int64_t *src_data = src.data_ptr<int64_t>();
    for (int32_t i = dim - 1; i >= 0; --i) {
      min_value = std::min(src_data[i], min_value);
      src[i] = min_value;
    }
  } else {
#ifdef FT_WITH_CUDA
    TORCH_CHECK(src.device().is_cuda());
    internal::MinOp<int64_t> min_op;
    auto src_data = src.data_ptr<int64_t>();
    internal::ConstReversedPtr<int64_t> src_ptr =
        internal::ConstReversedPtr<int64_t>(src_data, dim);
    internal::ReversedPtr<int64_t> dest_ptr =
        internal::ReversedPtr<int64_t>(src_data, dim);
    // The first time is to determine temporary device storage requirements.
    std::size_t temp_storage_bytes = 0;
    auto s = cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes,
                                            src_ptr, dest_ptr, min_op, dim);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
    auto d_temp = torch::empty({static_cast<int64_t>(temp_storage_bytes)},
                               torch::dtype(torch::kInt8).device(src.device()));
    int8_t *d_temp_storage = d_temp.data_ptr<int8_t>();
    s = cub::DeviceScan::InclusiveScan(d_temp_storage, temp_storage_bytes,
                                       src_ptr, dest_ptr, min_op, dim);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
#else
    TORCH_CHECK(false, "Please build with -DFT_WITH_CUDA=ON");
#endif  // FT_WITH_CUDA
  }
}

}  // namespace fast_rnnt
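One behavioral detail worth noting when comparing with the replacement: the operator above updates src in place (src[i] = min_value), while the new _monotonic_lower_bound added to rnnt_loss.py is out-of-place and leaves its argument untouched (its doctest shows x unchanged after the call). A tiny sketch of the difference; the in-place copy at the end is a hypothetical wrapper, not part of the project:

import torch

x = torch.tensor([0, 2, 1, 3, 6, 5, 8])

# Out-of-place, as in the new Python helper: x keeps its original values.
y = torch.flip(torch.cummin(torch.flip(x, dims=(-1,)), dim=-1).values, dims=(-1,))

# Hypothetical in-place variant mirroring the removed C++ semantics.
x.copy_(y)  # x is now tensor([0, 1, 1, 3, 5, 5, 8])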
fast_rnnt/csrc/utils.h (deleted, file mode 100644 → 0)

/**
 * Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FAST_RNNT_CSRC_UTILS_H_
#define FAST_RNNT_CSRC_UTILS_H_

#include "torch/extension.h"

#ifdef FT_WITH_CUDA
#include "cub/cub.cuh"  // NOLINT

namespace fast_rnnt {
namespace internal {

template <typename T> struct MinOp {
  __host__ __device__ __forceinline__ T operator()(const T &a,
                                                   const T &b) const {
    return (a > b) ? b : a;
  }
};

// Will be used (as both InputIterator and OutputIterator) in
// MonotonicLowerBound to call cub::DeviceScan::InclusiveScan.
template <typename T> struct ConstReversedPtr {
  const T *data;

  // data points to the last element now
  explicit ConstReversedPtr(const T *data, int32_t size)
      : data(data + size - 1) {}

  // operator[], operator+, and operator* are required by
  // cub::DeviceScan::InclusiveScan
  __host__ __device__ __forceinline__ const T &operator[](int32_t i) const {
    return data[-i];
  }
  __host__ __device__ __forceinline__ ConstReversedPtr operator+(int32_t n) const {
    ConstReversedPtr tmp(*this);
    tmp.data -= n;
    return tmp;
  }
  __host__ __device__ __forceinline__ const T &operator*() const {
    return *data;
  }
};

template <typename T> struct ReversedPtr {
  T *data;

  // data points to the last element now
  explicit ReversedPtr(T *data, int32_t size) : data(data + size - 1) {}

  // operator[], operator+, and operator* are required by
  // cub::DeviceScan::InclusiveScan
  __host__ __device__ __forceinline__ T &operator[](int32_t i) {
    return data[-i];
  }
  __host__ __device__ __forceinline__ ReversedPtr operator+(int32_t n) const {
    ReversedPtr tmp(*this);
    tmp.data -= n;
    return tmp;
  }
  __host__ __device__ __forceinline__ T &operator*() { return *data; }
};

}  // namespace internal
}  // namespace fast_rnnt

namespace std {
// value_type is required by cub::DeviceScan::InclusiveSum
template <typename T>
struct iterator_traits<fast_rnnt::internal::ConstReversedPtr<T>> {
  typedef T value_type;
};
template <typename T>
struct iterator_traits<fast_rnnt::internal::ReversedPtr<T>> {
  typedef T value_type;
};
}  // namespace std
#endif  // FT_WITH_CUDA

namespace fast_rnnt {
void MonotonicLowerBound(torch::Tensor &src);
}  // namespace fast_rnnt

#endif  // FAST_RNNT_CSRC_UTILS_H_
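The ConstReversedPtr/ReversedPtr wrappers above exist only to make cub's forward inclusive scan walk the buffer backwards: index i maps to data[-i] from the last element, so an inclusive min-scan over the reversed view writes the suffix minima of the original array. A small Python sketch of the same trick, with itertools.accumulate standing in for the inclusive scan (the function name is illustrative):

from itertools import accumulate

def suffix_min_via_reversed_scan(xs):
    # A forward inclusive min-scan over the reversed sequence ...
    scanned = list(accumulate(reversed(xs), min))
    # ... read back in reverse order is the suffix minimum of the original,
    # which is exactly what MonotonicLowerBound computes.
    return scanned[::-1]

print(suffix_min_via_reversed_scan([0, 2, 1, 3, 6, 5, 8]))  # [0, 1, 1, 3, 5, 5, 8]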
fast_rnnt/python/csrc/CMakeLists.txt

@@ -6,7 +6,6 @@ include(transform)
 set(fast_rnnt_srcs
   fast_rnnt.cu
   mutual_information.cu
-  utils.cu
 )
 if(NOT FT_WITH_CUDA)

@@ -14,7 +13,6 @@ if(NOT FT_WITH_CUDA)
 endif()
-
 pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
 target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
fast_rnnt/python/csrc/fast_rnnt.cu

@@ -20,11 +20,9 @@
 #include "fast_rnnt/python/csrc/fast_rnnt.h"
 #include "fast_rnnt/python/csrc/mutual_information.h"
-#include "fast_rnnt/python/csrc/utils.h"

 PYBIND11_MODULE(_fast_rnnt, m) {
   m.doc() = "Python wrapper for Fast Rnnt.";
   fast_rnnt::PybindMutualInformation(m);
-  fast_rnnt::PybindUtils(m);
 }
fast_rnnt/python/csrc/mutual_information.cu

@@ -65,5 +65,13 @@ void PybindMutualInformation(py::module &m) {
       },
       py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
       py::arg("ans_grad"));
+
+  m.def("with_cuda", []() -> bool {
+#ifdef FT_WITH_CUDA
+    return true;
+#else
+    return false;
+#endif
+  });
 }
 }  // namespace fast_rnnt
fast_rnnt/python/csrc/utils.cu (deleted, file mode 100644 → 0)

/**
 * @copyright
 * Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
 *
 * @copyright
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/utils.h"
#include "fast_rnnt/python/csrc/utils.h"

namespace fast_rnnt {

void PybindUtils(py::module &m) {
  m.def(
      "monotonic_lower_bound_",
      [](torch::Tensor &src) -> void {
        DeviceGuard guard(src.device());
        if (src.dim() == 1) {
          MonotonicLowerBound(src);
        } else if (src.dim() == 2) {
          int32_t dim0 = src.sizes()[0];
          for (int32_t i = 0; i < dim0; ++i) {
            auto sub = src.index({i});
            MonotonicLowerBound(sub);
          }
        } else {
          TORCH_CHECK(false,
                      "Only support 1 dimension and 2 dimensions tensor");
        }
      },
      py::arg("src"));

  m.def("with_cuda", []() -> bool {
#ifdef FT_WITH_CUDA
    return true;
#else
    return false;
#endif
  });
}

}  // namespace fast_rnnt
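The binding above dispatched on dimensionality: a 1-D tensor went straight to MonotonicLowerBound, while a 2-D tensor was processed row by row in a loop. The cummin-based replacement handles the batched case in a single call, because the scan runs along the last dimension of whatever shape it receives. A brief sketch; the input reuses two rows of the example in the new docstring:

import torch

x = torch.tensor([[12, 18,  5,  4, 18, 17],
                  [11, 14, 14,  3, 10,  4]])

# One vectorized call replaces the explicit loop over dim 0.
y = torch.flip(torch.cummin(torch.flip(x, dims=(-1,)), dim=-1).values, dims=(-1,))
print(y)
# tensor([[ 4,  4,  4,  4, 17, 17],
#         [ 3,  3,  3,  3,  4,  4]])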
fast_rnnt/python/csrc/utils.h (deleted, file mode 100644 → 0)

/**
 * @brief python wrappers for utils.h
 *
 * @copyright
 * Copyright 2022 Xiaomi Corp. (author: Wei Kang)
 *
 * @copyright
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FAST_RNNT_PYTHON_CSRC_UTILS_H_
#define FAST_RNNT_PYTHON_CSRC_UTILS_H_

#include "fast_rnnt/python/csrc/fast_rnnt.h"

namespace fast_rnnt {
void PybindUtils(py::module &m);
}  // namespace fast_rnnt

#endif  // FAST_RNNT_PYTHON_CSRC_UTILS_H_
fast_rnnt/python/fast_rnnt/__init__.py

-from _fast_rnnt import monotonic_lower_bound_
 from _fast_rnnt import with_cuda

 from .mutual_information import mutual_information_recursion
 ...
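With this commit, monotonic_lower_bound_ is no longer exported from the fast_rnnt package; with_cuda stays available (its binding now lives in mutual_information.cu). A minimal usage sketch, assuming the package is built and installed:

import torch
import fast_rnnt

# with_cuda() reports whether the extension was compiled with -DFT_WITH_CUDA=ON.
use_cuda = fast_rnnt.with_cuda() and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)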
fast_rnnt/python/fast_rnnt/rnnt_loss.py

@@ -16,7 +16,6 @@
 import os
-import fast_rnnt
 import torch
 from torch import Tensor
 from typing import Optional, Tuple, Union

@@ -521,6 +520,40 @@ def rnnt_loss(
     )

+def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
+    """Compute a monotonically increasing lower bound of the tensor `x` on the
+    last dimension. The basic idea is: we traverse the tensor in reverse order,
+    and update the current element with the following statement,
+        min_value = min(x[i], min_value)
+        x[i] = min_value
+    >>> import torch
+    >>> x = torch.tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
+    >>> _monotonic_lower_bound(x)
+    tensor([0, 1, 1, 3, 5, 5, 8], dtype=torch.int32)
+    >>> x
+    tensor([0, 2, 1, 3, 6, 5, 8], dtype=torch.int32)
+    >>> x = torch.randint(20, (3, 6), dtype=torch.int32)
+    >>> x
+    tensor([[12, 18,  5,  4, 18, 17],
+            [11, 14, 14,  3, 10,  4],
+            [19,  3,  8, 13,  7, 19]], dtype=torch.int32)
+    >>> _monotonic_lower_bound(x)
+    tensor([[ 4,  4,  4,  4, 17, 17],
+            [ 3,  3,  3,  3,  4,  4],
+            [ 3,  3,  7,  7,  7, 19]], dtype=torch.int32)
+    Args:
+      x:
+        The source tensor.
+    Returns:
+      Returns a tensor which is monotonic on the last dimension
+      (i.e. satisfies `x[i] <= x[i+1]`).
+    """
+    x = torch.flip(x, dims=(-1,))
+    x, _ = torch.cummin(x, dim=-1)
+    x = torch.flip(x, dims=(-1,))
+    return x
+
+
 def _adjust_pruning_lower_bound(
     s_begin: torch.Tensor, s_range: int
 ) -> torch.Tensor:

@@ -532,11 +565,10 @@ def _adjust_pruning_lower_bound(
     - s_begin[i + 1] - s_begin[i] < s_range, which means that we can't skip
       any symbols.

-    To make it monotonic increasing, we can use `monotonic_lower_bound` function
-    in k2, which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea is:
+    To make it monotonic increasing, we can use `_monotonic_lower_bound` above,
+    which guarantees `s_begin[i] <= s_begin[i + 1]`. The main idea is:
     traverse the array in reverse order and update the elements by
-    `min_value = min(a_begin[i], min_value)`, the initial `min_value` is set to
-    `inf`.
+    `min_value = min(a_begin[i], min_value)`.

     The method we used to realize `s_begin[i + 1] - s_begin[i] < s_range`
     constraint is a little tricky. We first transform `s_begin` with

@@ -559,13 +591,13 @@ def _adjust_pruning_lower_bound(
     """
     # s_begin (B, T)
     (B, T) = s_begin.shape
-    fast_rnnt.monotonic_lower_bound_(s_begin)
+    _monotonic_lower_bound(s_begin)
     # do the magic transformation
     s_begin = -(
         s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
     )
     # make the transformed tensor to be non-decreasing
-    fast_rnnt.monotonic_lower_bound_(s_begin)
+    _monotonic_lower_bound(s_begin)
     # make start symbol to be zero.
     s_begin = torch.where(s_begin < 0, 0, s_begin)
     # do the magic transformation again to recover s_begin
     ...
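To make the "magic transformation" in _adjust_pruning_lower_bound concrete, here is a hedged, self-contained trace on a tiny 1-D example. The real function operates on a (B, T) batch and contains steps not visible in this hunk; this only illustrates how applying the lower bound to the transformed sequence enforces s_begin[i + 1] - s_begin[i] < s_range (the initial lower-bound call from the real code is a no-op here because the example is already non-decreasing):

import torch

def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
    x = torch.flip(x, dims=(-1,))
    x, _ = torch.cummin(x, dim=-1)
    return torch.flip(x, dims=(-1,))

s_range = 3                                  # at most s_range - 1 = 2 new symbols per frame
s_begin = torch.tensor([0, 3, 3, 4])         # the jump 0 -> 3 violates the constraint
T = s_begin.numel()
offset = (s_range - 1) * torch.arange(0, T)

t = -(s_begin - offset)                         # the "magic transformation"
t = _monotonic_lower_bound(t)                   # make the transformed tensor non-decreasing
t = torch.where(t < 0, torch.zeros_like(t), t)  # make the start symbol zero
s_begin = -(t - offset)                         # transform again to recover s_begin

print(s_begin)  # tensor([0, 2, 3, 4]): every step now increases by at most 2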
fast_rnnt/python/tests/rnnt_loss_test.py

@@ -473,7 +473,7 @@ class TestRnntLoss(unittest.TestCase):
                 rnnt_type=rnnt_type,
             )
-            print(f"Unpruned rnnt loss with {rnnt_loss} rnnt : {fast_loss}")
+            print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {fast_loss}")

             # pruning
             simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(

@@ -578,7 +578,7 @@ class TestRnntLoss(unittest.TestCase):
             )
             S0 = 2
-            if rnnt_type == "regular":
+            if rnnt_type != "regular":
                 S0 = 1
             for r in range(S0, S + 2):