Unverified commit 6b07bcf8 authored by Vincent QB, committed by GitHub

Add RNN Transducer Loss for CPU (#1137)

parent 7d00504d
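
For context, a minimal sketch (not part of the diff) of how the new prototype loss is meant to be called; shapes follow the basic test below, and it assumes torchaudio was built with BUILD_TRANSDUCER=1:

    import torch
    from torchaudio.prototype.transducer import RNNTLoss

    # acts: (batch, time, max label length + 1, num classes) raw network output;
    # the CPU path applies log_softmax internally before calling the binding.
    acts = torch.rand(1, 2, 3, 5, requires_grad=True)
    labels = torch.tensor([[1, 2]], dtype=torch.int32)  # zero-padded label sequences
    act_lens = torch.tensor([2], dtype=torch.int32)     # output length per example
    label_lens = torch.tensor([2], dtype=torch.int32)   # label length per example

    loss = RNNTLoss(blank=0, reduction="mean")(acts, labels, act_lens, label_lens)
    loss.backward()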
@@ -66,5 +66,6 @@ build_master() {
     conda install -y -q pytorch "cpuonly" -c pytorch-nightly
     printf "* Installing torchaudio\n"
     cd "${_root_dir}" || exit 1
-    BUILD_SOX=1 python setup.py clean install
+    git submodule update --init --recursive
+    BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py clean install
 }
@@ -38,7 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit}
 # 2. Install torchaudio
 printf "* Installing torchaudio\n"
-BUILD_SOX=1 python setup.py install
+BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install
 # 3. Install Test tools
 printf "* Installing test tools\n"
...
@@ -43,6 +43,7 @@ conda activate "${env_dir}"
 pip --quiet install cmake ninja
 # 4. Build codecs
+git submodule update --init --recursive
 mkdir -p third_party/build
 (
     cd third_party/build
...
[submodule "third_party/warp_transducer/submodule"]
    path = third_party/transducer/submodule
    url = https://github.com/HawkAaron/warp-transducer
    ignore = dirty
@@ -20,20 +20,21 @@ _TP_BASE_DIR = _ROOT_DIR / 'third_party'
 _TP_INSTALL_DIR = _TP_BASE_DIR / 'install'
-def _get_build_sox():
-    val = os.environ.get('BUILD_SOX', '0')
+def _get_build(var):
+    val = os.environ.get(var, '0')
     trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
     falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
     if val in trues:
         return True
     if val not in falses:
         print(
-            f'WARNING: Unexpected environment variable value `BUILD_SOX={val}`. '
+            f'WARNING: Unexpected environment variable value `{var}={val}`. '
             f'Expected one of {trues + falses}')
     return False
-_BUILD_SOX = _get_build_sox()
+_BUILD_SOX = _get_build("BUILD_SOX")
+_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
 def _get_eca(debug):
@@ -42,6 +43,8 @@ def _get_eca(debug):
         eca += ["-O0", "-g"]
     else:
         eca += ["-O3"]
+    if _BUILD_TRANSDUCER:
+        eca += ['-DBUILD_TRANSDUCER']
     return eca
@@ -67,6 +70,8 @@ def _get_include_dirs():
     ]
     if _BUILD_SOX:
         dirs.append(str(_TP_INSTALL_DIR / 'include'))
+    if _BUILD_TRANSDUCER:
+        dirs.append(str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'))
     return dirs
@@ -94,6 +99,8 @@ def _get_extra_objects():
     ]
     for lib in libs:
         objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
+    if _BUILD_TRANSDUCER:
+        objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
     return objs
...
@@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then
     python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')"
     python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag
 else
-    BUILD_SOX=1 python setup.py bdist_wheel
+    BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py bdist_wheel
 fi
@@ -103,6 +103,7 @@ setup_macos() {
 #
 # Usage: setup_env 0.2.0
 setup_env() {
+    git submodule update --init --recursive
     setup_cuda
     setup_build_version "$1"
     setup_macos
...
 #!/usr/bin/env bash
 set -ex
-BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
+BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
import torch
from torchaudio.prototype.transducer import RNNTLoss
from torchaudio_unittest import common_utils
def get_data_basic(device):
    # Example provided
    # in 6f73a2513dc784c59eec153a45f40bc528355b18
    # of https://github.com/HawkAaron/warp-transducer
    acts = torch.tensor(
        [
            [
                [
                    [0.1, 0.6, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.6, 0.1, 0.1],
                    [0.1, 0.1, 0.2, 0.8, 0.1],
                ],
                [
                    [0.1, 0.6, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.2, 0.1, 0.1],
                    [0.7, 0.1, 0.2, 0.1, 0.1],
                ],
            ]
        ],
        dtype=torch.float,
    )
    labels = torch.tensor([[1, 2]], dtype=torch.int)
    act_length = torch.tensor([2], dtype=torch.int)
    label_length = torch.tensor([2], dtype=torch.int)

    acts = acts.to(device)
    labels = labels.to(device)
    act_length = act_length.to(device)
    label_length = label_length.to(device)

    acts.requires_grad_(True)

    return acts, labels, act_length, label_length
def get_data_B2_T4_U3_D3(dtype=torch.float32, device="cpu"):
    # Test from D21322854
    # Flattened (2, 4, 3, 3) tensor; three values per line, one line per
    # innermost dimension.
    logits = torch.tensor(
        [
            0.065357, 0.787530, 0.081592,
            0.529716, 0.750675, 0.754135,
            0.609764, 0.868140, 0.622532,
            0.668522, 0.858039, 0.164539,
            0.989780, 0.944298, 0.603168,
            0.946783, 0.666203, 0.286882,
            0.094184, 0.366674, 0.736168,
            0.166680, 0.714154, 0.399400,
            0.535982, 0.291821, 0.612642,
            0.324241, 0.800764, 0.524106,
            0.779195, 0.183314, 0.113745,
            0.240222, 0.339470, 0.134160,
            0.505562, 0.051597, 0.640290,
            0.430733, 0.829473, 0.177467,
            0.320700, 0.042883, 0.302803,
            0.675178, 0.569537, 0.558474,
            0.083132, 0.060165, 0.107958,
            0.748615, 0.943918, 0.486356,
            0.418199, 0.652408, 0.024243,
            0.134582, 0.366342, 0.295830,
            0.923670, 0.689929, 0.741898,
            0.250005, 0.603430, 0.987289,
            0.592606, 0.884672, 0.543450,
            0.660770, 0.377128, 0.358021,
        ],
        dtype=dtype,
    ).reshape(2, 4, 3, 3)
    targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32)
    src_lengths = torch.tensor([4, 4], dtype=torch.int32)
    tgt_lengths = torch.tensor([2, 2], dtype=torch.int32)
    blank = 0

    ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype)
    ref_gradients = torch.tensor(
        [
            -0.186844, -0.062555, 0.249399,
            -0.203377, 0.202399, 0.000977,
            -0.141016, 0.079123, 0.061893,
            -0.011552, -0.081280, 0.092832,
            -0.154257, 0.229433, -0.075176,
            -0.246593, 0.146405, 0.100188,
            -0.012918, -0.061593, 0.074512,
            -0.055986, 0.219831, -0.163845,
            -0.497627, 0.209240, 0.288387,
            0.013605, -0.030220, 0.016615,
            0.113925, 0.062781, -0.176706,
            -0.667078, 0.367659, 0.299419,
            -0.356344, -0.055347, 0.411691,
            -0.096922, 0.029459, 0.067463,
            -0.063518, 0.027654, 0.035863,
            -0.154499, -0.073942, 0.228441,
            -0.166790, -0.000088, 0.166878,
            -0.172370, 0.105565, 0.066804,
            0.023875, -0.118256, 0.094381,
            -0.104707, -0.108934, 0.213642,
            -0.369844, 0.180118, 0.189726,
            0.025714, -0.079462, 0.053748,
            0.122328, -0.238789, 0.116460,
            -0.598687, 0.302203, 0.296484,
        ],
        dtype=dtype,
    ).reshape(2, 4, 3, 3)

    logits.requires_grad_(True)
    logits = logits.to(device)

    def grad_hook(grad):
        logits.saved_grad = grad.clone()

    logits.register_hook(grad_hook)

    data = {
        "logits": logits,
        "targets": targets,
        "src_lengths": src_lengths,
        "tgt_lengths": tgt_lengths,
        "blank": blank,
    }

    return data, ref_costs, ref_gradients
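
Aside: the `register_hook` trick above captures the gradient arriving at `logits` during backward, whether or not the tensor remains a leaf after the `.to(device)` call. The pattern in isolation (illustrative, not part of the test file):

    import torch

    x = torch.rand(2, 3, requires_grad=True)

    def grad_hook(grad):
        # stash a copy of the incoming gradient on the tensor itself
        x.saved_grad = grad.clone()

    x.register_hook(grad_hook)
    x.sum().backward()
    assert torch.equal(x.saved_grad, torch.ones(2, 3))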
def compute_with_pytorch_transducer(data):
    costs = RNNTLoss(blank=data["blank"], reduction="none")(
        acts=data["logits"],
        labels=data["targets"],
        act_lens=data["src_lengths"],
        label_lens=data["tgt_lengths"],
    )

    loss = torch.sum(costs)
    loss.backward()
    costs = costs.cpu()
    gradients = data["logits"].saved_grad.cpu()
    return costs, gradients
class TransducerTester:
    def test_basic_fp16_error(self):
        rnnt_loss = RNNTLoss()
        acts, labels, act_length, label_length = get_data_basic(self.device)
        acts = acts.to(torch.float16)
        # RuntimeError raised by log_softmax before reaching transducer's bindings
        self.assertRaises(
            RuntimeError, rnnt_loss, acts, labels, act_length, label_length
        )

    def test_basic_backward(self):
        rnnt_loss = RNNTLoss()
        acts, labels, act_length, label_length = get_data_basic(self.device)
        loss = rnnt_loss(acts, labels, act_length, label_length)
        loss.backward()

    def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
        data, ref_costs, ref_gradients = get_data_B2_T4_U3_D3(
            dtype=torch.float32, device=self.device
        )
        logits_shape = data["logits"].shape
        costs, gradients = compute_with_pytorch_transducer(data=data)

        atol, rtol = 1e-6, 1e-2
        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
        self.assertEqual(logits_shape, gradients.shape)
        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)


@common_utils.skipIfNoExtension
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
    device = "cpu"
@@ -88,3 +88,5 @@ ExternalProject_Add(libsox
     # See https://github.com/pytorch/audio/pull/1026
     CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp --without-sndio --without-pulseaudio
 )
+
+add_subdirectory(transducer)
CMAKE_MINIMUM_REQUIRED(VERSION 3.5)

PROJECT(rnnt_release)

SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")

IF(APPLE)
    ADD_DEFINITIONS(-DAPPLE)
ENDIF()

INCLUDE_DIRECTORIES(submodule/include)

SET(CMAKE_POSITION_INDEPENDENT_CODE ON)
ADD_DEFINITIONS(-DRNNT_DISABLE_OMP)

IF(APPLE)
    EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
    STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION})
    MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}")

    # on El Capitan, have to use rpath
    IF(DARWIN_VERSION LESS 15)
        SET(CMAKE_SKIP_RPATH TRUE)
    ENDIF()
ELSE()
    # always skip rpath on Linux
    SET(CMAKE_SKIP_RPATH TRUE)
ENDIF()

ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp)

INSTALL(TARGETS warprnnt
        LIBRARY DESTINATION "lib"
        ARCHIVE DESTINATION "lib")

INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include")
Subproject commit f546575109111c455354861a0567c8aa794208a2
@@ -77,5 +77,19 @@ TORCH_LIBRARY(torchaudio, m) {
   m.def(
       "torchaudio::sox_effects_apply_effects_file",
       &torchaudio::sox_effects::apply_effects_file);
+
+  //////////////////////////////////////////////////////////////////////////////
+  // transducer.cpp
+  //////////////////////////////////////////////////////////////////////////////
+#ifdef BUILD_TRANSDUCER
+  m.def("rnnt_loss(Tensor acts,"
+        "Tensor labels,"
+        "Tensor input_lengths,"
+        "Tensor label_lengths,"
+        "Tensor costs,"
+        "Tensor grads,"
+        "int blank_label,"
+        "int num_threads) -> int");
+#endif
 }
 #endif
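
The schema above makes the caller responsible for allocating `costs` and `grads`, which the op fills in. A hedged sketch of invoking the raw op from Python, mirroring what the `_RNNT` autograd function below does (assumes a build with BUILD_TRANSDUCER=1 and that log_softmax was already applied):

    import torch
    import torchaudio  # loading the extension registers torchaudio::rnnt_loss

    acts = torch.nn.functional.log_softmax(torch.rand(1, 2, 3, 5), dim=-1)
    labels = torch.tensor([[1, 2]], dtype=torch.int32)
    input_lengths = torch.tensor([2], dtype=torch.int32)
    label_lengths = torch.tensor([2], dtype=torch.int32)

    costs = torch.zeros(1, dtype=acts.dtype)  # written by the op
    grads = torch.zeros_like(acts)            # written by the op
    torch.ops.torchaudio.rnnt_loss(
        acts, labels, input_lengths, label_lengths, costs, grads,
        0,  # blank_label
        0,  # num_threads (clamped to at least 1 in the C++ code)
    )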
#ifdef BUILD_TRANSDUCER

#include <algorithm>  // std::max
#include <cstring>    // memset
#include <string>
#include <vector>

#include <torch/script.h>

#include "rnnt.h"

int64_t cpu_rnnt_loss(torch::Tensor acts,
                      torch::Tensor labels,
                      torch::Tensor input_lengths,
                      torch::Tensor label_lengths,
                      torch::Tensor costs,
                      torch::Tensor grads,
                      int64_t blank_label,
                      int64_t num_threads) {
  int maxT = acts.size(1);
  int maxU = acts.size(2);
  int minibatch_size = acts.size(0);
  int alphabet_size = acts.size(3);

  rnntOptions options;
  memset(&options, 0, sizeof(options));
  options.maxT = maxT;
  options.maxU = maxU;
  options.blank_label = blank_label;
  options.batch_first = true;
  options.loc = RNNT_CPU;
  options.num_threads = num_threads;
  // have to use at least one thread
  options.num_threads = std::max(options.num_threads, (unsigned int)1);

  size_t cpu_size_bytes = 0;
  switch (acts.scalar_type()) {
    case torch::ScalarType::Float: {
      get_workspace_size(maxT, maxU, minibatch_size, false, &cpu_size_bytes);

      std::vector<float> cpu_workspace(cpu_size_bytes / sizeof(float), 0);
      compute_rnnt_loss(acts.data<float>(), grads.data<float>(),
                        labels.data<int>(), label_lengths.data<int>(),
                        input_lengths.data<int>(), alphabet_size,
                        minibatch_size, costs.data<float>(),
                        cpu_workspace.data(), options);
      return 0;
    }
    case torch::ScalarType::Double: {
      get_workspace_size(maxT, maxU, minibatch_size, false, &cpu_size_bytes,
                         sizeof(double));

      std::vector<double> cpu_workspace(cpu_size_bytes / sizeof(double), 0);
      compute_rnnt_loss_fp64(acts.data<double>(), grads.data<double>(),
                             labels.data<int>(), label_lengths.data<int>(),
                             input_lengths.data<int>(), alphabet_size,
                             minibatch_size, costs.data<double>(),
                             cpu_workspace.data(), options);
      return 0;
    }
    default:
      TORCH_CHECK(false,
                  std::string(__func__) + " not implemented for '" +
                      toString(acts.scalar_type()) + "'");
  }
  return -1;
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss", &cpu_rnnt_loss);
}

#endif
import torch
from torch.autograd import Function
from torch.nn import Module

from torchaudio._internal import (
    module_utils as _mod_utils,
)

__all__ = ["rnnt_loss", "RNNTLoss"]
class _RNNT(Function):
    @staticmethod
    def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction):
        """
        See documentation for RNNTLoss.
        """
        device = acts.device
        check_inputs(acts, labels, act_lens, label_lens)

        acts = acts.to("cpu")
        labels = labels.to("cpu")
        act_lens = act_lens.to("cpu")
        label_lens = label_lens.to("cpu")

        loss_func = torch.ops.torchaudio.rnnt_loss

        grads = torch.zeros_like(acts)
        minibatch_size = acts.size(0)
        costs = torch.zeros(minibatch_size, dtype=acts.dtype)
        loss_func(acts, labels, act_lens, label_lens, costs, grads, blank, 0)

        if reduction in ["sum", "mean"]:
            costs = costs.sum().unsqueeze_(-1)
        if reduction == "mean":
            costs /= minibatch_size
            grads /= minibatch_size

        costs = costs.to(device)
        ctx.grads = grads.to(device)

        return costs

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
        return ctx.grads.mul_(grad_output), None, None, None, None, None
@_mod_utils.requires_module("torchaudio._torchaudio")
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
    """Compute the RNN Transducer Loss.

    The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
    a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
    dependencies.

    The implementation uses `warp-transducer <https://github.com/HawkAaron/warp-transducer>`__.

    Args:
        acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network
            before applying ``torch.nn.functional.log_softmax``.
        labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
        act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence from the network
        label_lens (Tensor): Tensor of dimension (batch) containing the length of each label sequence
        blank (int): blank label. (Default: ``0``)
        reduction (string): If ``'sum'``, the output losses will be summed.
            If ``'mean'``, the output losses will be averaged over the batch.
            If ``'none'``, no reduction will be applied. (Default: ``'mean'``)
    """
    # NOTE: log_softmax is applied manually here for the CPU version;
    # the GPU version computes log_softmax internally.
    acts = torch.nn.functional.log_softmax(acts, -1)
    return _RNNT.apply(acts, labels, act_lens, label_lens, blank, reduction)
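
A short usage sketch of the functional form with ``reduction='none'``, reusing the shapes of the B2_T4_U3_D3 test data above (hypothetical values; requires a build with the transducer enabled):

    import torch
    from torchaudio.prototype.transducer import rnnt_loss

    acts = torch.rand(2, 4, 3, 3)  # (batch, time, max label length + 1, num classes)
    labels = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32)
    act_lens = torch.tensor([4, 4], dtype=torch.int32)
    label_lens = torch.tensor([2, 2], dtype=torch.int32)

    costs = rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="none")
    print(costs.shape)  # torch.Size([2]): one cost per batch element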
@_mod_utils.requires_module("torchaudio._torchaudio")
class RNNTLoss(Module):
    """Compute the RNN Transducer Loss.

    The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
    a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
    dependencies.

    The implementation uses `warp-transducer <https://github.com/HawkAaron/warp-transducer>`__.

    Args:
        blank (int): blank label. (Default: ``0``)
        reduction (string): If ``'sum'``, the output losses will be summed.
            If ``'mean'``, the output losses will be averaged over the batch.
            If ``'none'``, no reduction will be applied. (Default: ``'mean'``)
    """

    def __init__(self, blank=0, reduction="mean"):
        super(RNNTLoss, self).__init__()
        self.blank = blank
        self.reduction = reduction
        self.loss = _RNNT.apply

    def forward(self, acts, labels, act_lens, label_lens):
        """
        Args:
            acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network
                before applying ``torch.nn.functional.log_softmax``.
            labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero
            act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence from the network
            label_lens (Tensor): Tensor of dimension (batch) containing the length of each label sequence
        """
        # NOTE: log_softmax is applied manually here for the CPU version;
        # the GPU version computes log_softmax internally.
        acts = torch.nn.functional.log_softmax(acts, -1)
        return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction)
def check_type(var, t, name):
    if var.dtype is not t:
        raise TypeError("{} must be {}".format(name, t))


def check_contiguous(var, name):
    if not var.is_contiguous():
        raise ValueError("{} must be contiguous".format(name))


def check_dim(var, dim, name):
    if len(var.shape) != dim:
        raise ValueError("{} must be {}D".format(name, dim))


def check_inputs(log_probs, labels, lengths, label_lengths):
    check_type(labels, torch.int32, "labels")
    check_type(label_lengths, torch.int32, "label_lengths")
    check_type(lengths, torch.int32, "lengths")
    check_contiguous(log_probs, "log_probs")
    check_contiguous(labels, "labels")
    check_contiguous(label_lengths, "label_lengths")
    check_contiguous(lengths, "lengths")

    if lengths.shape[0] != log_probs.shape[0]:
        raise ValueError("must have a length per example.")
    if label_lengths.shape[0] != log_probs.shape[0]:
        raise ValueError("must have a label length per example.")

    check_dim(log_probs, 4, "log_probs")
    check_dim(labels, 2, "labels")
    check_dim(lengths, 1, "lengths")
    check_dim(label_lengths, 1, "label_lengths")

    max_T = torch.max(lengths)
    max_U = torch.max(label_lengths)
    T, U = log_probs.shape[1:3]

    if T != max_T:
        raise ValueError("Input length mismatch")
    if U != max_U + 1:
        raise ValueError("Output length mismatch")