Unverified commit 2c115821, authored by Caroline Chen, committed by GitHub

Move RNNT Loss out of prototype (#1711)

parent b7d44d97
......@@ -56,7 +56,7 @@ endif()
# Options
option(BUILD_SOX "Build libsox statically" OFF)
option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_RNNT "Enable RNN transducer" OFF)
option(BUILD_RNNT "Enable RNN transducer" ON)
option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_CUDA "Enable CUDA support" OFF)
......
......@@ -82,6 +82,7 @@ conda install -y -c pytorch-nightly torchaudio
The build process builds libsox and some codecs that torchaudio needs to link to. This is achieved by setting the environment variable `BUILD_SOX=1`.
The build process will fetch and build libmad, lame, flac, vorbis, opus, and libsox before building the extension. This process requires `cmake` and `pkg-config`.
The build process also builds the RNN transducer loss. This functionality can be disabled by setting the environment variable `BUILD_RNNT=0`.
```bash
# Linux
......
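As a concrete sketch of the paragraph above (illustrative only; the exact flags and platforms are covered by the full build instructions), a from-source build that keeps libsox but disables the transducer loss could look like:

```bash
# Illustrative only: build from source with libsox enabled and the
# RNN transducer loss turned off via the BUILD_RNNT switch.
BUILD_SOX=1 BUILD_RNNT=0 python setup.py install
```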
......@@ -36,7 +36,7 @@ def _get_build(var, default=False):
_BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX")
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True)
_BUILD_RNNT = _get_build("BUILD_RNNT")
_BUILD_RNNT = _get_build("BUILD_RNNT", True)
_USE_ROCM = _get_build("USE_ROCM")
_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available())
......
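`_get_build` itself is outside this hunk; a minimal sketch of how such an environment-variable switch could be implemented (an assumption for illustration, not the actual setup.py code):

```python
import os

def _get_build(var, default=False):
    # Hypothetical helper: fall back to `default` when the variable is unset,
    # otherwise treat common truthy strings as "enabled".
    if var not in os.environ:
        return default
    return os.environ[var].strip().lower() in ("1", "true", "on", "yes")
```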
......@@ -256,6 +256,14 @@ vad
.. autofunction:: spectral_centroid

:hidden:`Loss`
~~~~~~~~~~~~~~

rnnt_loss
---------

.. autofunction:: rnnt_loss

References
~~~~~~~~~~
......
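For orientation, a short usage sketch of the functional documented above (shapes are illustrative: `(batch, max_T, max_U, num_classes)` logits with int32 targets and lengths):

```python
import torch
import torchaudio.functional as F

# Toy example: batch=1, 2 time frames, 2 target tokens (so max_U = 3), 5 classes.
logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward()
```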
......@@ -39,7 +39,6 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
compliance.kaldi
kaldi_io
utils
rnnt_loss
.. toctree::
......
.. role:: hidden
    :class: hidden-section

torchaudio.prototype.rnnt_loss
==============================

.. currentmodule:: torchaudio.prototype.rnnt_loss

.. note::

    The RNN transducer loss is a prototype feature; see `here <https://pytorch.org/audio>`_ to learn more about the nomenclature. It is only available in the nightly builds and needs to be imported explicitly, e.g. :code:`from torchaudio.prototype.rnnt_loss import rnnt_loss, RNNTLoss`.

rnnt_loss
~~~~~~~~~

.. autofunction:: rnnt_loss

RNNTLoss
~~~~~~~~

.. autoclass:: RNNTLoss

  .. automethod:: forward

References
~~~~~~~~~~

.. footbibliography::
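Since the prototype page above is removed by this change, the import move in a nutshell (stable locations as documented in the updated functional and transforms pages):

```python
# Previously (prototype, nightly builds only):
#   from torchaudio.prototype.rnnt_loss import rnnt_loss, RNNTLoss

# After this change, the loss lives in the stable namespaces:
from torchaudio.functional import rnnt_loss
from torchaudio.transforms import RNNTLoss
```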
......@@ -171,6 +171,12 @@ Transforms are common audio transforms. They can be chained together using :clas
  .. automethod:: forward

:hidden:`RNNTLoss`
~~~~~~~~~~~~~~~~~~

.. autoclass:: RNNTLoss

  .. automethod:: forward

References
~~~~~~~~~~
......
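A hedged usage sketch of the module form documented above; the argument order is assumed to mirror the functional call used in the tests below, and the shapes are illustrative:

```python
import torch
import torchaudio.transforms as T

rnnt_loss_fn = T.RNNTLoss()  # defaults for blank index and reduction

logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

loss = rnnt_loss_fn(logits, targets, logit_lengths, target_lengths)
loss.backward()
```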
......@@ -6,7 +6,7 @@ SET(BUILD_LIBTORCHAUDIO ON CACHE BOOL "Build libtorchaudio")
SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio")
SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio")
SET(BUILD_RNNT OFF CACHE BOOL "Build RNN transducer into libtorchaudio")
SET(BUILD_RNNT ON CACHE BOOL "Build RNN transducer into libtorchaudio")
SET(BUILD_TORCHAUDIO_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding")
find_package(Torch REQUIRED)
......
......@@ -22,6 +22,7 @@ cmake -GNinja \
-DCMAKE_PREFIX_PATH="$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" \
-DBUILD_SOX=ON \
-DBUILD_KALDI=OFF \
-DBUILD_RNNT=ON \
..
cmake --build .
```
......
......@@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then
python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')"
"$script_dir/vc_env_helper.bat" python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag
else
BUILD_RNNT=1 BUILD_SOX=1 python setup.py bdist_wheel
BUILD_SOX=1 python setup.py bdist_wheel
fi
#!/usr/bin/env bash
set -ex
BUILD_RNNT=1 BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
import unittest
import random
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss
import numpy as np
from torchaudio.functional import rnnt_loss
class _NumpyTransducer(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        log_probs,
        logit_lengths,
        target_lengths,
        targets,
        blank=-1,
    ):
        device = log_probs.device
        log_probs = log_probs.cpu().data.numpy()
        logit_lengths = logit_lengths.cpu().data.numpy()
        target_lengths = target_lengths.cpu().data.numpy()
        targets = targets.cpu().data.numpy()

        gradients, costs, _, _ = __class__.compute(
            log_probs=log_probs,
            logit_lengths=logit_lengths,
            target_lengths=target_lengths,
            targets=targets,
            blank=blank,
        )

        costs = torch.FloatTensor(costs).to(device=device)
        gradients = torch.FloatTensor(gradients).to(device=device)
        ctx.grads = torch.autograd.Variable(gradients)

        return costs

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
        return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None
    @staticmethod
    def compute_alpha_one_sequence(log_probs, targets, blank=-1):
        max_T, max_U, D = log_probs.shape
        alpha = np.zeros((max_T, max_U), dtype=np.float32)
        for t in range(1, max_T):
            alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank]

        for u in range(1, max_U):
            alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]]

        for t in range(1, max_T):
            for u in range(1, max_U):
                skip = alpha[t - 1, u] + log_probs[t - 1, u, blank]
                emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]]
                alpha[t, u] = np.logaddexp(skip, emit)

        cost = -(alpha[-1, -1] + log_probs[-1, -1, blank])
        return alpha, cost

    @staticmethod
    def compute_beta_one_sequence(log_probs, targets, blank=-1):
        max_T, max_U, D = log_probs.shape
        beta = np.zeros((max_T, max_U), dtype=np.float32)
        beta[-1, -1] = log_probs[-1, -1, blank]

        for t in reversed(range(max_T - 1)):
            beta[t, -1] = beta[t + 1, -1] + log_probs[t, -1, blank]

        for u in reversed(range(max_U - 1)):
            beta[-1, u] = beta[-1, u + 1] + log_probs[-1, u, targets[u]]

        for t in reversed(range(max_T - 1)):
            for u in reversed(range(max_U - 1)):
                skip = beta[t + 1, u] + log_probs[t, u, blank]
                emit = beta[t, u + 1] + log_probs[t, u, targets[u]]
                beta[t, u] = np.logaddexp(skip, emit)

        cost = -beta[0, 0]
        return beta, cost
    @staticmethod
    def compute_gradients_one_sequence(
        log_probs, alpha, beta, targets, blank=-1
    ):
        max_T, max_U, D = log_probs.shape
        gradients = np.full(log_probs.shape, float("-inf"))
        cost = -beta[0, 0]

        gradients[-1, -1, blank] = alpha[-1, -1]
        gradients[:-1, :, blank] = alpha[:-1, :] + beta[1:, :]

        for u, l in enumerate(targets):
            gradients[:, u, l] = alpha[:, u] + beta[:, u + 1]

        gradients = -(np.exp(gradients + log_probs + cost))
        return gradients
    @staticmethod
    def compute(
        log_probs,
        logit_lengths,
        target_lengths,
        targets,
        blank=-1,
    ):
        gradients = np.zeros_like(log_probs)
        B_tgt, max_T, max_U, D = log_probs.shape
        B_src = logit_lengths.shape[0]

        H = int(B_tgt / B_src)

        alphas = np.zeros((B_tgt, max_T, max_U))
        betas = np.zeros((B_tgt, max_T, max_U))
        betas.fill(float("-inf"))
        alphas.fill(float("-inf"))
        costs = np.zeros(B_tgt)

        for b_tgt in range(B_tgt):
            b_src = int(b_tgt / H)
            T = int(logit_lengths[b_src])
            # NOTE: see https://arxiv.org/pdf/1211.3711.pdf Section 2.1
            U = int(target_lengths[b_tgt]) + 1

            seq_log_probs = log_probs[b_tgt, :T, :U, :]
            seq_targets = targets[b_tgt, : int(target_lengths[b_tgt])]
            alpha, alpha_cost = __class__.compute_alpha_one_sequence(
                log_probs=seq_log_probs, targets=seq_targets, blank=blank
            )

            beta, beta_cost = __class__.compute_beta_one_sequence(
                log_probs=seq_log_probs, targets=seq_targets, blank=blank
            )
            seq_gradients = __class__.compute_gradients_one_sequence(
                log_probs=seq_log_probs,
                alpha=alpha,
                beta=beta,
                targets=seq_targets,
                blank=blank,
            )
            np.testing.assert_almost_equal(alpha_cost, beta_cost, decimal=2)
            gradients[b_tgt, :T, :U, :] = seq_gradients
            costs[b_tgt] = beta_cost
            alphas[b_tgt, :T, :U] = alpha
            betas[b_tgt, :T, :U] = beta

        return gradients, costs, alphas, betas
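For reference, the recursion implemented by `compute_alpha_one_sequence` above, in the notation of Graves (2012), Section 2.1, where the blank label is written as the empty symbol:

```latex
\alpha(t, u) = \mathrm{logaddexp}\bigl(
    \alpha(t-1, u) + \log p(\varnothing \mid t-1, u),\;
    \alpha(t, u-1) + \log p(y_u \mid t, u-1)
\bigr),
\qquad
\mathrm{loss} = -\bigl(\alpha(T-1, U-1) + \log p(\varnothing \mid T-1, U-1)\bigr)
```

`compute_beta_one_sequence` runs the symmetric recursion backwards from `(T-1, U-1)`, and `compute_gradients_one_sequence` combines the two quantities into per-logit gradients.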
class NumpyTransducerLoss(torch.nn.Module):
    def __init__(self, blank=-1):
        super().__init__()
        self.blank = blank

    def forward(
        self,
        logits,
        logit_lengths,
        target_lengths,
        targets,
    ):
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        return _NumpyTransducer.apply(
            log_probs,
            logit_lengths,
            target_lengths,
            targets,
            self.blank,
        )
def compute_with_numpy_transducer(data):
......@@ -24,14 +189,13 @@ def compute_with_numpy_transducer(data):
def compute_with_pytorch_transducer(data):
    costs = RNNTLoss(
        blank=data["blank"],
        reduction="none",
    )(
    costs = rnnt_loss(
        logits=data["logits"],
        logit_lengths=data["logit_lengths"],
        target_lengths=data["target_lengths"],
        targets=data["targets"],
        blank=data["blank"],
        reduction="none",
    )
    loss = torch.sum(costs)
......
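To make the relationship between the two code paths concrete, a sketch along these lines could compare them directly (the module path and toy shapes are assumptions based on how `rnnt_utils` is imported elsewhere in this diff):

```python
import torch
from torchaudio.functional import rnnt_loss
from torchaudio_unittest.common_utils import rnnt_utils  # assumed location of the module above

torch.manual_seed(0)
logits = torch.rand(1, 2, 3, 5)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# Per-sequence costs from the NumPy reference implementation defined above.
ref_costs = rnnt_utils.NumpyTransducerLoss(blank=-1)(
    logits=logits,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    targets=targets,
)

# Per-sequence costs from torchaudio's implementation, with reduction disabled
# so the values can be compared one-to-one.
costs = rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank=-1,
    reduction="none",
)

torch.testing.assert_allclose(costs, ref_costs, atol=1e-4, rtol=1e-2)
```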
import torch
from .autograd_impl import Autograd
from .autograd_impl import Autograd, AutogradFloat32
from torchaudio_unittest import common_utils
class TestAutogradLfilterCPU(Autograd, common_utils.PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cpu')


class TestAutogradRNNTCPU(AutogradFloat32, common_utils.PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cpu')
import torch
from .autograd_impl import Autograd
from .autograd_impl import Autograd, AutogradFloat32
from torchaudio_unittest import common_utils
......@@ -7,3 +7,9 @@ from torchaudio_unittest import common_utils
class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestAutogradRNNTCUDA(AutogradFloat32, common_utils.PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cuda')
......@@ -8,6 +8,7 @@ from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
    rnnt_utils,
)
......@@ -192,3 +193,45 @@ class Autograd(TestBaseMixin):
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))
class AutogradFloat32(TestBaseMixin):
    def assert_grad(
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        enable_all_grad: bool = True,
    ):
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon
        assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.)

    @parameterized.expand([
        (rnnt_utils.get_B1_T10_U3_D4_data, ),
        (rnnt_utils.get_B2_T4_U3_D3_data, ),
        (rnnt_utils.get_B1_T2_U3_D5_data, ),
    ])
    def test_rnnt_loss(self, data_func):
        def get_data(data_func, device):
            data = data_func()
            if isinstance(data, tuple):
                data = data[0]
            return data

        data = get_data(data_func, self.device)
        inputs = (
            data["logits"].to(torch.float32),  # logits
            data["targets"],  # targets
            data["logit_lengths"],  # logit_lengths
            data["target_lengths"],  # target_lengths
            data["blank"],  # blank
            -1,  # clamp
        )
        self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)
......@@ -9,7 +9,13 @@ import torchaudio.functional as F
from parameterized import parameterized
from scipy import signal
from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested_params, get_whitenoise
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_sinusoid,
    nested_params,
    get_whitenoise,
    rnnt_utils,
)
class Functional(TestBaseMixin):
......@@ -42,6 +48,15 @@ class Functional(TestBaseMixin):
        self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)

    def _test_costs_and_gradients(
        self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
    ):
        logits_shape = data["logits"].shape
        costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)
        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
        self.assertEqual(logits_shape, gradients.shape)
        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)

    def test_lfilter_simple(self):
        """
        Create a very basic signal,
......@@ -436,6 +451,50 @@ class Functional(TestBaseMixin):
        waveform_shift = F.pitch_shift(waveform, sample_rate, n_steps)
        assert waveform.size() == waveform_shift.size()
    def test_rnnt_loss_basic_backward(self):
        logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
        loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
        loss.backward()

    def test_rnnt_loss_basic_forward_no_grad(self):
        """In the early stages, calls to `rnnt_loss` resulted in a segmentation fault when
        `logits` had `requires_grad = False`. This test makes sure that this no longer
        occurs and that the functional call runs without error.

        See https://github.com/pytorch/audio/pull/1707
        """
        logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
        logits.requires_grad_(False)
        F.rnnt_loss(logits, targets, logit_lengths, target_lengths)

    @parameterized.expand([
        (rnnt_utils.get_B1_T2_U3_D5_data, torch.float32, 1e-6, 1e-2),
        (rnnt_utils.get_B2_T4_U3_D3_data, torch.float32, 1e-6, 1e-2),
        (rnnt_utils.get_B1_T2_U3_D5_data, torch.float16, 1e-3, 1e-2),
        (rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
    ])
    def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
        data, ref_costs, ref_gradients = data_func(
            dtype=dtype,
            device=self.device,
        )
        self._test_costs_and_gradients(
            data=data,
            ref_costs=ref_costs,
            ref_gradients=ref_gradients,
            atol=atol,
            rtol=rtol,
        )

    def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
        seed = 777
        for i in range(5):
            data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
            ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
            self._test_costs_and_gradients(
                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
            )
class FunctionalCPUOnly(TestBaseMixin):
    def test_melscale_fbanks_no_warning_high_n_freq(self):
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalFloat32Only
class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cpu')
......
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalFloat32Only
@skipIfNoCuda
class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cuda')
......
......@@ -692,3 +692,21 @@ class Functional(TempDirMixin, TestBaseMixin):
        tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
        self._assert_consistency_complex(func, tensor, test_paseudo_complex)
class FunctionalFloat32Only(TestBaseMixin):
    def test_rnnt_loss(self):
        def func(tensor):
            targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
            logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
            target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
            return F.rnnt_loss(tensor, targets, logit_lengths, target_lengths)

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        tensor = logits.to(device=self.device, dtype=torch.float32)
        self._assert_consistency(func, tensor)
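The consistency test above exercises TorchScript support; a minimal standalone sketch of the same idea (values and shapes as in the test):

```python
import torch
import torchaudio.functional as F

@torch.jit.script
def scripted_rnnt_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    logit_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
) -> torch.Tensor:
    return F.rnnt_loss(logits, targets, logit_lengths, target_lengths)

logits = torch.rand(1, 2, 3, 5)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)
print(scripted_rnnt_loss(logits, targets, logit_lengths, target_lengths))
```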