Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
6b07bcf8
Unverified
Commit
6b07bcf8
authored
Jan 05, 2021
by
Vincent QB
Committed by
GitHub
Jan 05, 2021
Browse files
Add RNN Transducer Loss for CPU (#1137)
parent
7d00504d
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
594 additions
and
8 deletions
+594
-8
.circleci/torchscript_bc_test/common.sh
.circleci/torchscript_bc_test/common.sh
+2
-1
.circleci/unittest/linux/scripts/install.sh
.circleci/unittest/linux/scripts/install.sh
+1
-1
.circleci/unittest/linux/scripts/setup_env.sh
.circleci/unittest/linux/scripts/setup_env.sh
+1
-0
.gitmodules
.gitmodules
+4
-0
build_tools/setup_helpers/extension.py
build_tools/setup_helpers/extension.py
+11
-4
packaging/build_wheel.sh
packaging/build_wheel.sh
+1
-1
packaging/pkg_helpers.bash
packaging/pkg_helpers.bash
+1
-0
packaging/torchaudio/build.sh
packaging/torchaudio/build.sh
+1
-1
test/torchaudio_unittest/transducer_test.py
test/torchaudio_unittest/transducer_test.py
+276
-0
third_party/CMakeLists.txt
third_party/CMakeLists.txt
+2
-0
third_party/transducer/CMakeLists.txt
third_party/transducer/CMakeLists.txt
+38
-0
third_party/transducer/submodule
third_party/transducer/submodule
+1
-0
torchaudio/csrc/register.cpp
torchaudio/csrc/register.cpp
+14
-0
torchaudio/csrc/transducer.cpp
torchaudio/csrc/transducer.cpp
+82
-0
torchaudio/prototype/__init__.py
torchaudio/prototype/__init__.py
+0
-0
torchaudio/prototype/transducer.py
torchaudio/prototype/transducer.py
+159
-0
No files found.
.circleci/torchscript_bc_test/common.sh
View file @
6b07bcf8
...
@@ -66,5 +66,6 @@ build_master() {
...
@@ -66,5 +66,6 @@ build_master() {
conda
install
-y
-q
pytorch
"cpuonly"
-c
pytorch-nightly
conda
install
-y
-q
pytorch
"cpuonly"
-c
pytorch-nightly
printf
"* Installing torchaudio
\n
"
printf
"* Installing torchaudio
\n
"
cd
"
${
_root_dir
}
"
||
exit
1
cd
"
${
_root_dir
}
"
||
exit
1
BUILD_SOX
=
1 python setup.py clean
install
git submodule update
--init
--recursive
BUILD_TRANSDUCER
=
1
BUILD_SOX
=
1 python setup.py clean
install
}
}
.circleci/unittest/linux/scripts/install.sh
View file @
6b07bcf8
...
@@ -38,7 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit}
...
@@ -38,7 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit}
# 2. Install torchaudio
# 2. Install torchaudio
printf
"* Installing torchaudio
\n
"
printf
"* Installing torchaudio
\n
"
BUILD_SOX
=
1 python setup.py
install
BUILD_TRANSDUCER
=
1
BUILD_SOX
=
1 python setup.py
install
# 3. Install Test tools
# 3. Install Test tools
printf
"* Installing test tools
\n
"
printf
"* Installing test tools
\n
"
...
...
.circleci/unittest/linux/scripts/setup_env.sh
View file @
6b07bcf8
...
@@ -43,6 +43,7 @@ conda activate "${env_dir}"
...
@@ -43,6 +43,7 @@ conda activate "${env_dir}"
pip
--quiet
install
cmake ninja
pip
--quiet
install
cmake ninja
# 4. Build codecs
# 4. Build codecs
git submodule update
--init
--recursive
mkdir
-p
third_party/build
mkdir
-p
third_party/build
(
(
cd
third_party/build
cd
third_party/build
...
...
.gitmodules
0 → 100644
View file @
6b07bcf8
[submodule "third_party/warp_transducer/submodule"]
path = third_party/transducer/submodule
url = https://github.com/HawkAaron/warp-transducer
ignore = dirty
build_tools/setup_helpers/extension.py
View file @
6b07bcf8
...
@@ -20,20 +20,21 @@ _TP_BASE_DIR = _ROOT_DIR / 'third_party'
...
@@ -20,20 +20,21 @@ _TP_BASE_DIR = _ROOT_DIR / 'third_party'
_TP_INSTALL_DIR
=
_TP_BASE_DIR
/
'install'
_TP_INSTALL_DIR
=
_TP_BASE_DIR
/
'install'
def
_get_build
_sox
(
):
def
_get_build
(
var
):
val
=
os
.
environ
.
get
(
'BUILD_SOX'
,
'0'
)
val
=
os
.
environ
.
get
(
var
,
'0'
)
trues
=
[
'1'
,
'true'
,
'TRUE'
,
'on'
,
'ON'
,
'yes'
,
'YES'
]
trues
=
[
'1'
,
'true'
,
'TRUE'
,
'on'
,
'ON'
,
'yes'
,
'YES'
]
falses
=
[
'0'
,
'false'
,
'FALSE'
,
'off'
,
'OFF'
,
'no'
,
'NO'
]
falses
=
[
'0'
,
'false'
,
'FALSE'
,
'off'
,
'OFF'
,
'no'
,
'NO'
]
if
val
in
trues
:
if
val
in
trues
:
return
True
return
True
if
val
not
in
falses
:
if
val
not
in
falses
:
print
(
print
(
f
'WARNING: Unexpected environment variable value `
BUILD_SOX
=
{
val
}
`. '
f
'WARNING: Unexpected environment variable value `
{
var
}
=
{
val
}
`. '
f
'Expected one of
{
trues
+
falses
}
'
)
f
'Expected one of
{
trues
+
falses
}
'
)
return
False
return
False
_BUILD_SOX
=
_get_build_sox
()
_BUILD_SOX
=
_get_build
(
"BUILD_SOX"
)
_BUILD_TRANSDUCER
=
_get_build
(
"BUILD_TRANSDUCER"
)
def
_get_eca
(
debug
):
def
_get_eca
(
debug
):
...
@@ -42,6 +43,8 @@ def _get_eca(debug):
...
@@ -42,6 +43,8 @@ def _get_eca(debug):
eca
+=
[
"-O0"
,
"-g"
]
eca
+=
[
"-O0"
,
"-g"
]
else
:
else
:
eca
+=
[
"-O3"
]
eca
+=
[
"-O3"
]
if
_BUILD_TRANSDUCER
:
eca
+=
[
'-DBUILD_TRANSDUCER'
]
return
eca
return
eca
...
@@ -67,6 +70,8 @@ def _get_include_dirs():
...
@@ -67,6 +70,8 @@ def _get_include_dirs():
]
]
if
_BUILD_SOX
:
if
_BUILD_SOX
:
dirs
.
append
(
str
(
_TP_INSTALL_DIR
/
'include'
))
dirs
.
append
(
str
(
_TP_INSTALL_DIR
/
'include'
))
if
_BUILD_TRANSDUCER
:
dirs
.
append
(
str
(
_TP_BASE_DIR
/
'transducer'
/
'submodule'
/
'include'
))
return
dirs
return
dirs
...
@@ -94,6 +99,8 @@ def _get_extra_objects():
...
@@ -94,6 +99,8 @@ def _get_extra_objects():
]
]
for
lib
in
libs
:
for
lib
in
libs
:
objs
.
append
(
str
(
_TP_INSTALL_DIR
/
'lib'
/
lib
))
objs
.
append
(
str
(
_TP_INSTALL_DIR
/
'lib'
/
lib
))
if
_BUILD_TRANSDUCER
:
objs
.
append
(
str
(
_TP_BASE_DIR
/
'build'
/
'transducer'
/
'libwarprnnt.a'
))
return
objs
return
objs
...
...
packaging/build_wheel.sh
View file @
6b07bcf8
...
@@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then
...
@@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then
python_tag
=
"
$(
echo
"cp
$PYTHON_VERSION
"
|
tr
-d
'.'
)
"
python_tag
=
"
$(
echo
"cp
$PYTHON_VERSION
"
|
tr
-d
'.'
)
"
python setup.py bdist_wheel
--plat-name
win_amd64
--python-tag
$python_tag
python setup.py bdist_wheel
--plat-name
win_amd64
--python-tag
$python_tag
else
else
BUILD_SOX
=
1 python setup.py bdist_wheel
BUILD_TRANSDUCER
=
1
BUILD_SOX
=
1 python setup.py bdist_wheel
fi
fi
packaging/pkg_helpers.bash
View file @
6b07bcf8
...
@@ -103,6 +103,7 @@ setup_macos() {
...
@@ -103,6 +103,7 @@ setup_macos() {
#
#
# Usage: setup_env 0.2.0
# Usage: setup_env 0.2.0
setup_env
()
{
setup_env
()
{
git submodule update
--init
--recursive
setup_cuda
setup_cuda
setup_build_version
"
$1
"
setup_build_version
"
$1
"
setup_macos
setup_macos
...
...
packaging/torchaudio/build.sh
View file @
6b07bcf8
#!/usr/bin/env bash
#!/usr/bin/env bash
set
-ex
set
-ex
BUILD_SOX
=
1 python setup.py
install
--single-version-externally-managed
--record
=
record.txt
BUILD_TRANSDUCER
=
1
BUILD_SOX
=
1 python setup.py
install
--single-version-externally-managed
--record
=
record.txt
test/torchaudio_unittest/transducer_test.py
0 → 100644
View file @
6b07bcf8
import
torch
from
torchaudio.prototype.transducer
import
RNNTLoss
from
torchaudio_unittest
import
common_utils
def get_data_basic(device):
    """Build the minimal one-sample fixture (acts, labels, act_length, label_length).

    Values reproduce the example provided in
    6f73a2513dc784c59eec153a45f40bc528355b18 of
    https://github.com/HawkAaron/warp-transducer
    """
    acts = torch.tensor(
        [[[[0.1, 0.6, 0.1, 0.1, 0.1],
           [0.1, 0.1, 0.6, 0.1, 0.1],
           [0.1, 0.1, 0.2, 0.8, 0.1]],
          [[0.1, 0.6, 0.1, 0.1, 0.1],
           [0.1, 0.1, 0.2, 0.1, 0.1],
           [0.7, 0.1, 0.2, 0.1, 0.1]]]],
        dtype=torch.float,
    ).to(device)
    labels = torch.tensor([[1, 2]], dtype=torch.int).to(device)
    act_length = torch.tensor([2], dtype=torch.int).to(device)
    label_length = torch.tensor([2], dtype=torch.int).to(device)

    # Gradients are requested on the device copy, as in the original.
    acts.requires_grad_(True)
    return acts, labels, act_length, label_length
def get_data_B2_T4_U3_D3(dtype=torch.float32, device="cpu"):
    """Return (data, ref_costs, ref_gradients) for a fixed B=2, T=4, U=3, D=3 case.

    ``data`` packs the inputs in the form ``compute_with_pytorch_transducer``
    expects; ``ref_costs``/``ref_gradients`` are precomputed reference values.
    """
    # Test from D21322854
    logits = torch.tensor(
        [
            0.065357, 0.787530, 0.081592, 0.529716, 0.750675, 0.754135,
            0.609764, 0.868140, 0.622532, 0.668522, 0.858039, 0.164539,
            0.989780, 0.944298, 0.603168, 0.946783, 0.666203, 0.286882,
            0.094184, 0.366674, 0.736168, 0.166680, 0.714154, 0.399400,
            0.535982, 0.291821, 0.612642, 0.324241, 0.800764, 0.524106,
            0.779195, 0.183314, 0.113745, 0.240222, 0.339470, 0.134160,
            0.505562, 0.051597, 0.640290, 0.430733, 0.829473, 0.177467,
            0.320700, 0.042883, 0.302803, 0.675178, 0.569537, 0.558474,
            0.083132, 0.060165, 0.107958, 0.748615, 0.943918, 0.486356,
            0.418199, 0.652408, 0.024243, 0.134582, 0.366342, 0.295830,
            0.923670, 0.689929, 0.741898, 0.250005, 0.603430, 0.987289,
            0.592606, 0.884672, 0.543450, 0.660770, 0.377128, 0.358021,
        ],
        dtype=dtype,
    ).reshape(2, 4, 3, 3)
    targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32)
    src_lengths = torch.tensor([4, 4], dtype=torch.int32)
    tgt_lengths = torch.tensor([2, 2], dtype=torch.int32)
    blank = 0
    ref_costs = torch.tensor(
        [4.2806528590890736, 3.9384369822503591], dtype=dtype
    )
    ref_gradients = torch.tensor(
        [
            -0.186844, -0.062555, 0.249399, -0.203377, 0.202399, 0.000977,
            -0.141016, 0.079123, 0.061893, -0.011552, -0.081280, 0.092832,
            -0.154257, 0.229433, -0.075176, -0.246593, 0.146405, 0.100188,
            -0.012918, -0.061593, 0.074512, -0.055986, 0.219831, -0.163845,
            -0.497627, 0.209240, 0.288387, 0.013605, -0.030220, 0.016615,
            0.113925, 0.062781, -0.176706, -0.667078, 0.367659, 0.299419,
            -0.356344, -0.055347, 0.411691, -0.096922, 0.029459, 0.067463,
            -0.063518, 0.027654, 0.035863, -0.154499, -0.073942, 0.228441,
            -0.166790, -0.000088, 0.166878, -0.172370, 0.105565, 0.066804,
            0.023875, -0.118256, 0.094381, -0.104707, -0.108934, 0.213642,
            -0.369844, 0.180118, 0.189726, 0.025714, -0.079462, 0.053748,
            0.122328, -0.238789, 0.116460, -0.598687, 0.302203, 0.296484,
        ],
        dtype=dtype,
    ).reshape(2, 4, 3, 3)

    # NOTE(review): requires_grad_ is set before the .to(device) copy, so on a
    # non-CPU device `logits` ends up non-leaf; the hook below is what captures
    # the gradient either way — confirm this ordering is intended.
    logits.requires_grad_(True)
    logits = logits.to(device)

    def grad_hook(grad):
        # Stash a copy of the incoming gradient on the tensor itself so the
        # caller can read it back after backward().
        logits.saved_grad = grad.clone()

    logits.register_hook(grad_hook)

    data = {
        "logits": logits,
        "targets": targets,
        "src_lengths": src_lengths,
        "tgt_lengths": tgt_lengths,
        "blank": blank,
    }
    return data, ref_costs, ref_gradients
def compute_with_pytorch_transducer(data):
    """Run torchaudio's RNNT loss on *data*, returning (costs, gradients) on CPU."""
    loss_fn = RNNTLoss(blank=data["blank"], reduction="none")
    costs = loss_fn(
        acts=data["logits"],
        labels=data["targets"],
        act_lens=data["src_lengths"],
        label_lens=data["tgt_lengths"],
    )
    torch.sum(costs).backward()
    # The grad hook installed by the data builder stashed a copy of the
    # gradient on the logits tensor as `saved_grad`.
    gradients = data["logits"].saved_grad.cpu()
    return costs.cpu(), gradients
class TransducerTester:
    """Shared RNNT-loss test cases; concrete subclasses supply ``device``."""

    def test_basic_fp16_error(self):
        # fp16 input must be rejected before reaching the native op.
        rnnt_loss = RNNTLoss()
        acts, labels, act_length, label_length = get_data_basic(self.device)
        acts = acts.to(torch.float16)
        # RuntimeError raised by log_softmax before reaching transducer's bindings
        self.assertRaises(
            RuntimeError, rnnt_loss, acts, labels, act_length, label_length
        )

    def test_basic_backward(self):
        # Smoke test: forward + backward run without error on the basic fixture.
        rnnt_loss = RNNTLoss()
        acts, labels, act_length, label_length = get_data_basic(self.device)
        loss = rnnt_loss(acts, labels, act_length, label_length)
        loss.backward()

    def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
        # Compare costs and gradients against precomputed reference values.
        data, ref_costs, ref_gradients = get_data_B2_T4_U3_D3(
            dtype=torch.float32, device=self.device
        )
        logits_shape = data["logits"].shape
        costs, gradients = compute_with_pytorch_transducer(data=data)

        atol, rtol = 1e-6, 1e-2
        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
        self.assertEqual(logits_shape, gradients.shape)
        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
@common_utils.skipIfNoExtension
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
    # Run the shared TransducerTester cases on CPU; skipped when the native
    # extension is not built.
    device = "cpu"
third_party/CMakeLists.txt
View file @
6b07bcf8
...
@@ -88,3 +88,5 @@ ExternalProject_Add(libsox
...
@@ -88,3 +88,5 @@ ExternalProject_Add(libsox
# See https://github.com/pytorch/audio/pull/1026
# See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND
${
CMAKE_CURRENT_SOURCE_DIR
}
/build_codec_helper.sh
${
CMAKE_CURRENT_SOURCE_DIR
}
/src/libsox/configure
${
COMMON_ARGS
}
--with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp --without-sndio --without-pulseaudio
CONFIGURE_COMMAND
${
CMAKE_CURRENT_SOURCE_DIR
}
/build_codec_helper.sh
${
CMAKE_CURRENT_SOURCE_DIR
}
/src/libsox/configure
${
COMMON_ARGS
}
--with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp --without-sndio --without-pulseaudio
)
)
add_subdirectory
(
transducer
)
third_party/transducer/CMakeLists.txt
0 → 100755
View file @
6b07bcf8
# Build the warp-transducer submodule's CPU entry point as a static library
# (libwarprnnt.a) that the torchaudio extension links against.
CMAKE_MINIMUM_REQUIRED(VERSION 3.5)

PROJECT(rnnt_release)

SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")

IF(APPLE)
    ADD_DEFINITIONS(-DAPPLE)
ENDIF()

INCLUDE_DIRECTORIES(submodule/include)

# Static archive is linked into a shared extension module, so PIC is required.
SET(CMAKE_POSITION_INDEPENDENT_CODE ON)
# NOTE(review): presumably disables the submodule's OpenMP code paths — confirm.
ADD_DEFINITIONS(-DRNNT_DISABLE_OMP)

IF(APPLE)
    EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
    STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION})
    MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}")

    # for el capitain have to use rpath
    IF(DARWIN_VERSION LESS 15)
        SET(CMAKE_SKIP_RPATH TRUE)
    ENDIF()
ELSE()
    # always skip for linux
    SET(CMAKE_SKIP_RPATH TRUE)
ENDIF()

ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp)

INSTALL(TARGETS warprnnt
        LIBRARY DESTINATION "lib"
        ARCHIVE DESTINATION "lib")

INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include")
submodule
@
f5465751
Subproject commit f546575109111c455354861a0567c8aa794208a2
torchaudio/csrc/register.cpp
View file @
6b07bcf8
...
@@ -77,5 +77,19 @@ TORCH_LIBRARY(torchaudio, m) {
...
@@ -77,5 +77,19 @@ TORCH_LIBRARY(torchaudio, m) {
m
.
def
(
m
.
def
(
"torchaudio::sox_effects_apply_effects_file"
,
"torchaudio::sox_effects_apply_effects_file"
,
&
torchaudio
::
sox_effects
::
apply_effects_file
);
&
torchaudio
::
sox_effects
::
apply_effects_file
);
//////////////////////////////////////////////////////////////////////////////
// transducer.cpp
//////////////////////////////////////////////////////////////////////////////
#ifdef BUILD_TRANSDUCER
m
.
def
(
"rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
"Tensor label_lengths,"
"Tensor costs,"
"Tensor grads,"
"int blank_label,"
"int num_threads) -> int"
);
#endif
}
}
#endif
#endif
torchaudio/csrc/transducer.cpp
0 → 100644
View file @
6b07bcf8
#ifdef BUILD_TRANSDUCER
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#include <torch/script.h>
#include "rnnt.h"
int64_t
cpu_rnnt_loss
(
torch
::
Tensor
acts
,
torch
::
Tensor
labels
,
torch
::
Tensor
input_lengths
,
torch
::
Tensor
label_lengths
,
torch
::
Tensor
costs
,
torch
::
Tensor
grads
,
int64_t
blank_label
,
int64_t
num_threads
)
{
int
maxT
=
acts
.
size
(
1
);
int
maxU
=
acts
.
size
(
2
);
int
minibatch_size
=
acts
.
size
(
0
);
int
alphabet_size
=
acts
.
size
(
3
);
rnntOptions
options
;
memset
(
&
options
,
0
,
sizeof
(
options
));
options
.
maxT
=
maxT
;
options
.
maxU
=
maxU
;
options
.
blank_label
=
blank_label
;
options
.
batch_first
=
true
;
options
.
loc
=
RNNT_CPU
;
options
.
num_threads
=
num_threads
;
// have to use at least one
options
.
num_threads
=
std
::
max
(
options
.
num_threads
,
(
unsigned
int
)
1
);
size_t
cpu_size_bytes
=
0
;
switch
(
acts
.
scalar_type
())
{
case
torch
::
ScalarType
::
Float
:
{
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
);
std
::
vector
<
float
>
cpu_workspace
(
cpu_size_bytes
/
sizeof
(
float
),
0
);
compute_rnnt_loss
(
acts
.
data
<
float
>
(),
grads
.
data
<
float
>
(),
labels
.
data
<
int
>
(),
label_lengths
.
data
<
int
>
(),
input_lengths
.
data
<
int
>
(),
alphabet_size
,
minibatch_size
,
costs
.
data
<
float
>
(),
cpu_workspace
.
data
(),
options
);
return
0
;
}
case
torch
::
ScalarType
::
Double
:
{
get_workspace_size
(
maxT
,
maxU
,
minibatch_size
,
false
,
&
cpu_size_bytes
,
sizeof
(
double
));
std
::
vector
<
double
>
cpu_workspace
(
cpu_size_bytes
/
sizeof
(
double
),
0
);
compute_rnnt_loss_fp64
(
acts
.
data
<
double
>
(),
grads
.
data
<
double
>
(),
labels
.
data
<
int
>
(),
label_lengths
.
data
<
int
>
(),
input_lengths
.
data
<
int
>
(),
alphabet_size
,
minibatch_size
,
costs
.
data
<
double
>
(),
cpu_workspace
.
data
(),
options
);
return
0
;
}
default:
TORCH_CHECK
(
false
,
std
::
string
(
__func__
)
+
" not implemented for '"
+
toString
(
acts
.
scalar_type
())
+
"'"
);
}
return
-
1
;
}
TORCH_LIBRARY_IMPL
(
torchaudio
,
CPU
,
m
)
{
m
.
impl
(
"rnnt_loss"
,
&
cpu_rnnt_loss
);
}
#endif
torchaudio/prototype/__init__.py
0 → 100644
View file @
6b07bcf8
torchaudio/prototype/transducer.py
0 → 100644
View file @
6b07bcf8
import
torch
from
torch.autograd
import
Function
from
torch.nn
import
Module
from
torchaudio._internal
import
(
module_utils
as
_mod_utils
,
)
__all__
=
[
"rnnt_loss"
,
"RNNTLoss"
]
class _RNNT(Function):
    # Autograd bridge to the native torchaudio::rnnt_loss op. The op is
    # CPU-only here, so all inputs are moved to CPU for the call and the
    # results moved back to the caller's device.
    @staticmethod
    def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction):
        """
        See documentation for RNNTLoss.
        """
        device = acts.device
        check_inputs(acts, labels, act_lens, label_lens)

        acts = acts.to("cpu")
        labels = labels.to("cpu")
        act_lens = act_lens.to("cpu")
        label_lens = label_lens.to("cpu")

        loss_func = torch.ops.torchaudio.rnnt_loss

        # costs/grads are output buffers filled in place by the native op.
        grads = torch.zeros_like(acts)
        minibatch_size = acts.size(0)
        costs = torch.zeros(minibatch_size, dtype=acts.dtype)

        # Final 0 = num_threads; the native binding clamps it to at least 1.
        loss_func(acts, labels, act_lens, label_lens, costs, grads, blank, 0)

        if reduction in ["sum", "mean"]:
            # Reduce to a 1-element tensor; "mean" divides by batch size.
            costs = costs.sum().unsqueeze_(-1)
            if reduction == "mean":
                costs /= minibatch_size
                grads /= minibatch_size

        costs = costs.to(device)
        # Gradients are stashed on ctx for use in backward().
        ctx.grads = grads.to(device)

        return costs

    @staticmethod
    def backward(ctx, grad_output):
        # Broadcast the per-batch upstream gradient over (B, T, U, D).
        grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
        # NOTE(review): mul_ mutates the saved grads in place, so a second
        # backward pass through this node would see scaled values — confirm
        # double-backward is not expected here.
        return ctx.grads.mul_(grad_output), None, None, None, None, None
@_mod_utils.requires_module("torchaudio._torchaudio")
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
    """Compute the RNN Transducer Loss.

    The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
    a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
    dependencies.

    The implementation uses `warp-transducer <https://github.com/HawkAaron/warp-transducer>`__.

    Args:
        acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network
            before applying ``torch.nn.functional.log_softmax``.
        labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
        act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
        label_lens (Tensor): Tensor of dimension (batch) containing the length of each label sequence
        blank (int): blank label. (Default: ``0``)
        reduction (string): If ``'sum'``, the output losses will be summed.
            If ``'mean'``, the summed losses will be divided by the batch size.
            If ``'none'``, no reduction will be applied.
            (Default: ``'mean'``)
    """
    # NOTE manually done log_softmax for CPU version,
    # log_softmax is computed within GPU version.
    acts = torch.nn.functional.log_softmax(acts, -1)
    return _RNNT.apply(acts, labels, act_lens, label_lens, blank, reduction)
@_mod_utils.requires_module("torchaudio._torchaudio")
class RNNTLoss(Module):
    """Compute the RNN Transducer Loss.

    The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
    a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
    dependencies.

    The implementation uses `warp-transducer <https://github.com/HawkAaron/warp-transducer>`__.

    Args:
        blank (int): blank label. (Default: ``0``)
        reduction (string): If ``'sum'``, the output losses will be summed.
            If ``'mean'``, the summed losses will be divided by the batch size.
            If ``'none'``, no reduction will be applied.
            (Default: ``'mean'``)
    """

    def __init__(self, blank=0, reduction="mean"):
        super(RNNTLoss, self).__init__()
        self.blank = blank
        self.reduction = reduction
        self.loss = _RNNT.apply

    def forward(self, acts, labels, act_lens, label_lens):
        """
        Args:
            acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network
                before applying ``torch.nn.functional.log_softmax``.
            labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
            act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence
            label_lens (Tensor): Tensor of dimension (batch) containing the length of each label sequence
        """
        # NOTE manually done log_softmax for CPU version,
        # log_softmax is computed within GPU version.
        acts = torch.nn.functional.log_softmax(acts, -1)
        return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction)
def check_type(var, t, name):
    """Raise TypeError unless *var* (named *name* in the message) has dtype *t*."""
    if var.dtype is t:
        return
    raise TypeError("{} must be {}".format(name, t))
def check_contiguous(var, name):
    """Raise ValueError unless *var* is contiguous in memory."""
    if var.is_contiguous():
        return
    raise ValueError("{} must be contiguous".format(name))
def check_dim(var, dim, name):
    """Raise ValueError unless *var* has exactly *dim* dimensions."""
    if len(var.shape) == dim:
        return
    raise ValueError("{} must be {}D".format(name, dim))
def check_inputs(log_probs, labels, lengths, label_lengths):
    """Validate dtypes, contiguity, and shapes of the RNNT loss inputs.

    Args:
        log_probs (Tensor): 4D (batch, time, label, class) log-probabilities.
        labels (Tensor): 2D int32 (batch, max label length) padded labels.
        lengths (Tensor): 1D int32 (batch) input lengths.
        label_lengths (Tensor): 1D int32 (batch) label lengths.

    Raises:
        TypeError: if any integer tensor is not int32.
        ValueError: on non-contiguous input or any shape mismatch.
    """
    check_type(labels, torch.int32, "labels")
    check_type(label_lengths, torch.int32, "label_lengths")
    check_type(lengths, torch.int32, "lengths")
    check_contiguous(log_probs, "log_probs")
    check_contiguous(labels, "labels")
    check_contiguous(label_lengths, "label_lengths")
    check_contiguous(lengths, "lengths")

    if lengths.shape[0] != log_probs.shape[0]:
        raise ValueError("must have a length per example.")
    if label_lengths.shape[0] != log_probs.shape[0]:
        raise ValueError("must have a label length per example.")

    check_dim(log_probs, 4, "log_probs")
    check_dim(labels, 2, "labels")
    # Fixed typos in the reported names: "lenghts" -> "lengths",
    # "label_lenghts" -> "label_lengths".
    check_dim(lengths, 1, "lengths")
    check_dim(label_lengths, 1, "label_lengths")
    max_T = torch.max(lengths)
    max_U = torch.max(label_lengths)
    T, U = log_probs.shape[1:3]

    if T != max_T:
        raise ValueError("Input length mismatch")
    # log_probs' label axis includes the extra blank-prefixed step, hence +1.
    if U != max_U + 1:
        raise ValueError("Output length mismatch")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment