"stubs/vscode:/vscode.git/clone" did not exist on "bde4bac5fe3f6b040ac6d75e3bd631be7f504c27"
Unverified commit 2c115821 authored by Caroline Chen, committed by GitHub

Move RNNT Loss out of prototype (#1711)

parent b7d44d97
import torch
from .autograd_impl import Autograd
from torchaudio_unittest import common_utils
from .utils import skipIfNoRNNT
@skipIfNoRNNT
class TestAutograd(Autograd, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
import torch
from .autograd_impl import Autograd
from torchaudio_unittest import common_utils
from .utils import skipIfNoRNNT
@skipIfNoRNNT
@common_utils.skipIfNoCuda
class TestAutograd(Autograd, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
from typing import Callable, Tuple
import torch
from torch import Tensor
from torch.autograd import gradcheck
from torchaudio_unittest.common_utils import (
TestBaseMixin,
)
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
from parameterized import parameterized
from .utils import (
get_B1_T10_U3_D4_data,
get_B2_T4_U3_D3_data,
get_B1_T2_U3_D5_data
)
from .numpy_transducer import NumpyTransducerLoss
class Autograd(TestBaseMixin):
@staticmethod
def get_data(data_func, device):
data = data_func()
if type(data) == tuple:
data = data[0]
return data
def assert_grad(
self,
loss: Callable[..., Tensor],
inputs: Tuple[torch.Tensor],
*,
enable_all_grad: bool = True,
):
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.dtype, device=self.device)
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
# gradcheck with float32 requires higher atol and epsilon
assert gradcheck(loss, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.)
@parameterized.expand([
(get_B1_T10_U3_D4_data, ),
(get_B2_T4_U3_D3_data, ),
(get_B1_T2_U3_D5_data, ),
])
def test_RNNTLoss_gradcheck(self, data_func):
data = self.get_data(data_func, self.device)
inputs = (
data["logits"].to(self.dtype),
data["targets"],
data["logit_lengths"],
data["target_lengths"],
)
loss = RNNTLoss(blank=data["blank"])
self.assert_grad(loss, inputs, enable_all_grad=False)
@parameterized.expand([
(get_B1_T10_U3_D4_data, ),
(get_B2_T4_U3_D3_data, ),
(get_B1_T2_U3_D5_data, ),
])
def test_rnnt_loss_gradcheck(self, data_func):
data = self.get_data(data_func, self.device)
inputs = (
data["logits"].to(self.dtype), # logits
data["targets"], # targets
data["logit_lengths"], # logit_lengths
data["target_lengths"], # target_lengths
data["blank"], # blank
-1, # clamp
)
self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
@parameterized.expand([
(get_B1_T10_U3_D4_data, ),
(get_B2_T4_U3_D3_data, ),
(get_B1_T2_U3_D5_data, ),
])
def test_np_transducer_gradcheck(self, data_func):
data = self.get_data(data_func, self.device)
inputs = (
data["logits"].to(self.dtype),
data["logit_lengths"],
data["target_lengths"],
data["targets"],
)
loss = NumpyTransducerLoss(blank=data["blank"])
self.assert_grad(loss, inputs, enable_all_grad=False)
import numpy as np
import torch
class _NumpyTransducer(torch.autograd.Function):
@staticmethod
def forward(
ctx,
log_probs,
logit_lengths,
target_lengths,
targets,
blank=-1,
):
device = log_probs.device
log_probs = log_probs.cpu().data.numpy()
logit_lengths = logit_lengths.cpu().data.numpy()
target_lengths = target_lengths.cpu().data.numpy()
targets = targets.cpu().data.numpy()
gradients, costs, _, _ = __class__.compute(
log_probs=log_probs,
logit_lengths=logit_lengths,
target_lengths=target_lengths,
targets=targets,
blank=blank,
)
costs = torch.FloatTensor(costs).to(device=device)
gradients = torch.FloatTensor(gradients).to(device=device)
ctx.grads = torch.autograd.Variable(gradients)
return costs
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None
@staticmethod
def compute_alpha_one_sequence(log_probs, targets, blank=-1):
max_T, max_U, D = log_probs.shape
alpha = np.zeros((max_T, max_U), dtype=np.float32)
for t in range(1, max_T):
alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank]
for u in range(1, max_U):
alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]]
for t in range(1, max_T):
for u in range(1, max_U):
skip = alpha[t - 1, u] + log_probs[t - 1, u, blank]
emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]]
alpha[t, u] = np.logaddexp(skip, emit)
cost = -(alpha[-1, -1] + log_probs[-1, -1, blank])
return alpha, cost
@staticmethod
def compute_beta_one_sequence(log_probs, targets, blank=-1):
max_T, max_U, D = log_probs.shape
beta = np.zeros((max_T, max_U), dtype=np.float32)
beta[-1, -1] = log_probs[-1, -1, blank]
for t in reversed(range(max_T - 1)):
beta[t, -1] = beta[t + 1, -1] + log_probs[t, -1, blank]
for u in reversed(range(max_U - 1)):
beta[-1, u] = beta[-1, u + 1] + log_probs[-1, u, targets[u]]
for t in reversed(range(max_T - 1)):
for u in reversed(range(max_U - 1)):
skip = beta[t + 1, u] + log_probs[t, u, blank]
emit = beta[t, u + 1] + log_probs[t, u, targets[u]]
beta[t, u] = np.logaddexp(skip, emit)
cost = -beta[0, 0]
return beta, cost
@staticmethod
def compute_gradients_one_sequence(
log_probs, alpha, beta, targets, blank=-1
):
max_T, max_U, D = log_probs.shape
gradients = np.full(log_probs.shape, float("-inf"))
cost = -beta[0, 0]
gradients[-1, -1, blank] = alpha[-1, -1]
gradients[:-1, :, blank] = alpha[:-1, :] + beta[1:, :]
for u, l in enumerate(targets):
gradients[:, u, l] = alpha[:, u] + beta[:, u + 1]
gradients = -(np.exp(gradients + log_probs + cost))
return gradients
@staticmethod
def compute(
log_probs,
logit_lengths,
target_lengths,
targets,
blank=-1,
):
gradients = np.zeros_like(log_probs)
B_tgt, max_T, max_U, D = log_probs.shape
B_src = logit_lengths.shape[0]
H = int(B_tgt / B_src)
alphas = np.zeros((B_tgt, max_T, max_U))
betas = np.zeros((B_tgt, max_T, max_U))
betas.fill(float("-inf"))
alphas.fill(float("-inf"))
costs = np.zeros(B_tgt)
for b_tgt in range(B_tgt):
b_src = int(b_tgt / H)
T = int(logit_lengths[b_src])
# NOTE: see https://arxiv.org/pdf/1211.3711.pdf Section 2.1
U = int(target_lengths[b_tgt]) + 1
seq_log_probs = log_probs[b_tgt, :T, :U, :]
seq_targets = targets[b_tgt, : int(target_lengths[b_tgt])]
alpha, alpha_cost = __class__.compute_alpha_one_sequence(
log_probs=seq_log_probs, targets=seq_targets, blank=blank
)
beta, beta_cost = __class__.compute_beta_one_sequence(
log_probs=seq_log_probs, targets=seq_targets, blank=blank
)
seq_gradients = __class__.compute_gradients_one_sequence(
log_probs=seq_log_probs,
alpha=alpha,
beta=beta,
targets=seq_targets,
blank=blank,
)
np.testing.assert_almost_equal(alpha_cost, beta_cost, decimal=2)
gradients[b_tgt, :T, :U, :] = seq_gradients
costs[b_tgt] = beta_cost
alphas[b_tgt, :T, :U] = alpha
betas[b_tgt, :T, :U] = beta
return gradients, costs, alphas, betas
class NumpyTransducerLoss(torch.nn.Module):
def __init__(self, blank=-1):
super().__init__()
self.blank = blank
def forward(
self,
logits,
logit_lengths,
target_lengths,
targets,
):
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
return _NumpyTransducer.apply(
log_probs,
logit_lengths,
target_lengths,
targets,
self.blank,
)
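For reference, the recursions implemented by compute_alpha_one_sequence, compute_beta_one_sequence, and compute_gradients_one_sequence above can be written compactly (notation roughly following Graves, Sequence Transduction with Recurrent Neural Networks, 2012; blank denoted by \varnothing) as

\begin{aligned}
\alpha(t,u) &= \operatorname{logaddexp}\bigl(\alpha(t-1,u) + \log p(\varnothing \mid t-1,u),\ \alpha(t,u-1) + \log p(y_u \mid t,u-1)\bigr) \\
\beta(t,u)  &= \operatorname{logaddexp}\bigl(\beta(t+1,u) + \log p(\varnothing \mid t,u),\ \beta(t,u+1) + \log p(y_{u+1} \mid t,u)\bigr) \\
-\log P(\mathbf{y}\mid\mathbf{x}) &= -\beta(0,0) = -\bigl(\alpha(T-1,U-1) + \log p(\varnothing \mid T-1,U-1)\bigr)
\end{aligned}

with the gradient used in compute_gradients_one_sequence given by \partial(-\log P)/\partial \log p(k \mid t,u) = -\exp\bigl(\alpha(t,u) + \beta(t',u') + \log p(k \mid t,u) - \log P\bigr), where (t',u') is (t+1,u) when k is the blank label and (t,u+1) when k = y_{u+1}.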
import torch
from torchaudio_unittest import common_utils
from .utils import skipIfNoRNNT
from .rnnt_loss_impl import RNNTLossTest
@skipIfNoRNNT
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
device = torch.device('cpu')
import torch
from .rnnt_loss_impl import RNNTLossTest
from torchaudio_unittest import common_utils
from .utils import skipIfNoRNNT
@skipIfNoRNNT
@common_utils.skipIfNoCuda
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
device = torch.device('cuda')
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss
from .utils import (
compute_with_numpy_transducer,
compute_with_pytorch_transducer,
get_basic_data,
get_B1_T2_U3_D5_data,
get_B2_T4_U3_D3_data,
get_random_data,
)
class RNNTLossTest:
def _test_costs_and_gradients(
self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
):
logits_shape = data["logits"].shape
costs, gradients = compute_with_pytorch_transducer(data=data)
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
def test_basic_backward(self):
rnnt_loss = RNNTLoss()
logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward()
def test_basic_forward_no_grad(self):
rnnt_loss = RNNTLoss()
logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
logits.requires_grad_(False)
rnnt_loss(logits, targets, logit_lengths, target_lengths)
def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
dtype=torch.float32,
device=self.device,
)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
def test_costs_and_gradients_B1_T2_U3_D5_fp16(self):
data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
dtype=torch.float16,
device=self.device,
)
self._test_costs_and_gradients(
data=data,
ref_costs=ref_costs,
ref_gradients=ref_gradients,
atol=1e-3,
rtol=1e-2,
)
def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
dtype=torch.float32,
device=self.device,
)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
def test_costs_and_gradients_B2_T4_U3_D3_fp16(self):
data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
dtype=torch.float16,
device=self.device,
)
self._test_costs_and_gradients(
data=data,
ref_costs=ref_costs,
ref_gradients=ref_gradients,
atol=1e-3,
rtol=1e-2,
)
def test_costs_and_gradients_random_data_with_numpy_fp32(self):
seed = 777
for i in range(5):
data = get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .utils import skipIfNoRNNT
from .torchscript_consistency_impl import RNNTLossTorchscript
@skipIfNoRNNT
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
device = torch.device('cpu')
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .utils import skipIfNoRNNT
from .torchscript_consistency_impl import RNNTLossTorchscript
@skipIfNoRNNT
@skipIfNoCuda
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
device = torch.device('cuda')
import torch
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
"""Implements test for RNNT Loss that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)
path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)
torch.random.manual_seed(40)
input_tensor = tensor.clone().detach().requires_grad_(True)
output = func(input_tensor)
torch.random.manual_seed(40)
input_tensor = tensor.clone().detach().requires_grad_(True)
ts_output = ts_func(input_tensor)
self.assertEqual(ts_output, output)
def test_rnnt_loss(self):
def func(
logits,
):
targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
return rnnt_loss(logits, targets, logit_lengths, target_lengths)
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])
self._assert_consistency(func, logits)
def test_RNNTLoss(self):
func = RNNTLoss()
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])
targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
tensor = logits.to(device=self.device, dtype=self.dtype)
path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)
torch.random.manual_seed(40)
input_tensor = tensor.clone().detach().requires_grad_(True)
output = func(input_tensor, targets, logit_lengths, target_lengths)
torch.random.manual_seed(40)
input_tensor = tensor.clone().detach().requires_grad_(True)
ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)
self.assertEqual(ts_output, output)
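Outside the test harness, the same TorchScript round trip can be reproduced directly. A minimal sketch, assuming torchaudio is built with the RNNT extension (toy shapes only):

import torch
import torchaudio.transforms as T

scripted = torch.jit.script(T.RNNTLoss())  # RNNTLoss.forward is TorchScript-compatible
logits = torch.rand(1, 2, 3, 5)  # (batch, frames, target length + 1, classes)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
lengths = torch.tensor([2], dtype=torch.int32)
print(scripted(logits, targets, lengths, lengths))  # matches the eager module's output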
from torchaudio_unittest.common_utils import PytorchTestCase
from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32
class AutogradCPUTest(AutogradTestMixin, PytorchTestCase):
device = 'cpu'
class AutogradRNNTCPUTest(AutogradTestFloat32, PytorchTestCase):
device = 'cpu'
@@ -2,9 +2,14 @@ from torchaudio_unittest.common_utils import (
PytorchTestCase,
skipIfNoCuda,
)
from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32
@skipIfNoCuda
class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
device = 'cuda'
@skipIfNoCuda
class AutogradRNNTCUDATest(AutogradTestFloat32, PytorchTestCase):
device = 'cuda'
@@ -11,6 +11,7 @@ from torchaudio_unittest.common_utils import (
get_whitenoise,
get_spectrogram,
nested_params,
rnnt_utils,
)
@@ -260,3 +261,41 @@ class AutogradTestMixin(TestBaseMixin):
if test_pseudo_complex:
spectrogram = torch.view_as_real(spectrogram)
self.assert_grad(transform, [spectrogram])
class AutogradTestFloat32(TestBaseMixin):
def assert_grad(
self,
transform: torch.nn.Module,
inputs: List[torch.Tensor],
):
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=torch.float32, device=self.device)
inputs_.append(i)
# gradcheck with float32 requires higher atol and epsilon
assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.)
@parameterized.expand([
(rnnt_utils.get_B1_T10_U3_D4_data, ),
(rnnt_utils.get_B2_T4_U3_D3_data, ),
(rnnt_utils.get_B1_T2_U3_D5_data, ),
])
def test_rnnt_loss(self, data_func):
def get_data(data_func, device):
data = data_func()
if type(data) == tuple:
data = data[0]
return data
data = get_data(data_func, self.device)
inputs = (
data["logits"].to(torch.float32),
data["targets"],
data["logit_lengths"],
data["target_lengths"],
)
loss = T.RNNTLoss(blank=data["blank"])
self.assert_grad(loss, inputs)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
...
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
@skipIfNoCuda
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
...
@@ -14,7 +14,7 @@ from torchaudio_unittest.common_utils import (
class Transforms(TempDirMixin, TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor, *args):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
@@ -22,8 +22,8 @@ class Transforms(TempDirMixin, TestBaseMixin):
torch.jit.script(transform).save(path)
ts_transform = torch.jit.load(path)
output = transform(tensor, *args)
ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False):
@@ -155,3 +155,19 @@ class Transforms(TempDirMixin, TestBaseMixin):
T.PitchShift(sample_rate=sample_rate, n_steps=n_steps),
waveform
)
class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1]]]])
tensor = logits.to(device=self.device, dtype=torch.float32)
targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
self._assert_consistency(T.RNNTLoss(), logits, targets, logit_lengths, target_lengths)
@@ -25,6 +25,7 @@ from .functional import (
resample,
edit_distance,
pitch_shift,
rnnt_loss,
)
from .filtering import (
allpass_biquad,
@@ -98,4 +99,5 @@ __all__ = [
'resample',
'edit_distance',
'pitch_shift',
'rnnt_loss',
]
@@ -40,6 +40,7 @@ __all__ = [
"resample",
"edit_distance",
"pitch_shift",
"rnnt_loss",
]
@@ -1745,3 +1746,55 @@ def pitch_shift(
# unpack batch
waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
return waveform_shift
def rnnt_loss(
logits: Tensor,
targets: Tensor,
logit_lengths: Tensor,
target_lengths: Tensor,
blank: int = -1,
clamp: float = -1,
reduction: str = "mean",
):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
[:footcite:`graves2012sequence`].
The RNN Transducer loss extends the CTC loss by defining a distribution over output
sequences of all lengths, and by jointly modelling both input-output and output-output
dependencies.
Args:
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padding
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
blank (int, optional): blank label (Default: ``-1``)
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
Returns:
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
otherwise scalar.
"""
if reduction not in ['none', 'mean', 'sum']:
raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
if blank < 0: # reinterpret blank index if blank < 0.
blank = logits.shape[-1] + blank
costs, _ = torch.ops.torchaudio.rnnt_loss(
logits=logits,
targets=targets,
logit_lengths=logit_lengths,
target_lengths=target_lengths,
blank=blank,
clamp=clamp,
)
if reduction == 'mean':
return costs.mean()
elif reduction == 'sum':
return costs.sum()
return costs
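As a quick sanity check of the shapes documented above, a minimal call to the functional form could look like the sketch below (sizes and values are illustrative only, and torch.ops.torchaudio.rnnt_loss is only available when torchaudio is built with the transducer extension):

import torch
import torchaudio.functional as F

B, T, U, D = 2, 10, 3, 5  # batch, max frames, max target length, number of classes
logits = torch.randn(B, T, U + 1, D, requires_grad=True)  # joiner output
targets = torch.randint(0, D - 1, (B, U), dtype=torch.int32)  # labels in [0, D-2]; D-1 is the default blank
logit_lengths = torch.full((B,), T, dtype=torch.int32)
target_lengths = torch.full((B,), U, dtype=torch.int32)

loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)  # scalar, 'mean' reduction
loss.backward()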
import torch
from torch import Tensor
__all__ = [
"RNNTLoss",
"rnnt_loss",
]
def rnnt_loss(
logits: Tensor,
targets: Tensor,
logit_lengths: Tensor,
target_lengths: Tensor,
blank: int = -1,
clamp: float = -1,
reduction: str = "mean",
):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
[:footcite:`graves2012sequence`].
The RNN Transducer loss extends the CTC loss by defining a distribution over output
sequences of all lengths, and by jointly modelling both input-output and output-output
dependencies.
Args:
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padding
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
blank (int, optional): blank label (Default: ``-1``)
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
Returns:
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
otherwise scalar.
"""
if reduction not in ['none', 'mean', 'sum']:
raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
if blank < 0: # reinterpret blank index if blank < 0.
blank = logits.shape[-1] + blank
costs, _ = torch.ops.torchaudio.rnnt_loss(
logits=logits,
targets=targets,
logit_lengths=logit_lengths,
target_lengths=target_lengths,
blank=blank,
clamp=clamp,
)
if reduction == 'mean':
return costs.mean()
elif reduction == 'sum':
return costs.sum()
return costs
class RNNTLoss(torch.nn.Module):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
[:footcite:`graves2012sequence`].
The RNN Transducer loss extends the CTC loss by defining a distribution over output
sequences of all lengths, and by jointly modelling both input-output and output-output
dependencies.
Args:
blank (int, optional): blank label (Default: ``-1``)
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
"""
def __init__(
self,
blank: int = -1,
clamp: float = -1.,
reduction: str = "mean",
):
super().__init__()
self.blank = blank
self.clamp = clamp
self.reduction = reduction
def forward(
self,
logits,
targets,
logit_lengths,
target_lengths,
):
"""
Args:
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padding
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
Returns:
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
otherwise scalar.
"""
return rnnt_loss(
logits,
targets,
logit_lengths,
target_lengths,
self.blank,
self.clamp,
self.reduction
)
@@ -37,6 +37,7 @@ __all__ = [
'Vol',
'ComputeDeltas',
'PitchShift',
'RNNTLoss',
]
@@ -1428,3 +1429,57 @@ class PitchShift(torch.nn.Module):
return F.pitch_shift(waveform, self.sample_rate, self.n_steps, self.bins_per_octave, self.n_fft,
self.win_length, self.hop_length, self.window)
class RNNTLoss(torch.nn.Module):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
[:footcite:`graves2012sequence`].
The RNN Transducer loss extends the CTC loss by defining a distribution over output
sequences of all lengths, and by jointly modelling both input-output and output-output
dependencies.
Args:
blank (int, optional): blank label (Default: ``-1``)
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
"""
def __init__(
self,
blank: int = -1,
clamp: float = -1.,
reduction: str = "mean",
):
super().__init__()
self.blank = blank
self.clamp = clamp
self.reduction = reduction
def forward(
self,
logits: Tensor,
targets: Tensor,
logit_lengths: Tensor,
target_lengths: Tensor,
):
"""
Args:
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padding
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
Returns:
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
otherwise scalar.
"""
return F.rnnt_loss(
logits,
targets,
logit_lengths,
target_lengths,
self.blank,
self.clamp,
self.reduction
)
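The module form added to torchaudio.transforms mirrors the functional API; a short hedged sketch using the same toy values as the TorchScript consistency test (again assuming torchaudio is built with the RNNT extension):

import torch
import torchaudio.transforms as T

logits = torch.rand(1, 2, 3, 5, requires_grad=True)  # (batch, frames, target length + 1, classes)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

rnnt_loss_module = T.RNNTLoss(blank=-1, reduction='mean')
loss = rnnt_loss_module(logits, targets, logit_lengths, target_lengths)
loss.backward()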