Unverified commit 0c263a93 authored by Caroline Chen, committed by GitHub

Replace existing prototype RNNT Loss (#1479)

Replace the prototype RNNT implementation (using warp-transducer) with one without external library dependencies
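A minimal usage sketch of the replacement loss (mirroring test_basic_backward in the new test suite; the constructor defaults are not restated here and may differ across nightlies):

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

rnnt_loss = RNNTLoss()
# (batch, max_seq_length, max_target_length + 1, num_classes)
logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward()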
parent b5d80279
[submodule "third_party/warp_transducer/submodule"]
path = third_party/transducer/submodule
url = https://github.com/HawkAaron/warp-transducer
ignore = dirty
[submodule "kaldi"]
path = third_party/kaldi/submodule
url = https://github.com/kaldi-asr/kaldi
......
@@ -21,7 +21,7 @@ Features described in this documentation are classified by release status:
*Prototype:* These features are typically not available as part of
binary distributions like PyPI or Conda, except sometimes behind run-time
flags, and are at an early stage for feedback and testing.
The :mod:`torchaudio` package consists of I/O, popular datasets and common audio transformations.
@@ -39,9 +39,9 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
compliance.kaldi
kaldi_io
utils
transducer
rnnt_loss
.. toctree::
:maxdepth: 1
:caption: PyTorch Libraries
......
.. role:: hidden
:class: hidden-section
torchaudio.prototype.transducer
torchaudio.prototype.rnnt_loss
===============================
.. currentmodule:: torchaudio.prototype.transducer
.. currentmodule:: torchaudio.prototype.rnnt_loss
.. note::
The RNN transducer loss is a prototype feature, see `here <https://pytorch.org/audio>`_ to learn more about the nomenclature. It is only available within the nightlies, and also needs to be imported explicitly using: :code:`from torchaudio.prototype.transducer import rnnt_loss, RNNTLoss`.
The RNN transducer loss is a prototype feature, see `here <https://pytorch.org/audio>`_ to learn more about the nomenclature. It is only available within the nightlies, and also needs to be imported explicitly using: :code:`from torchaudio.prototype.rnnt_loss import rnnt_loss, RNNTLoss`.
rnnt_loss
---------
......
@@ -6,7 +6,7 @@ SET(BUILD_LIBTORCHAUDIO ON CACHE BOOL "Build libtorchaudio")
SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio")
SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio")
SET(BUILD_TRANSDUCER OFF CACHE BOOL "Build Python binding")
SET(BUILD_TRANSDUCER OFF CACHE BOOL "Build transducer into libtorchaudio")
SET(BUILD_TORCHAUDIO_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding")
find_package(Torch REQUIRED)
......
import numpy as np
import torch
class _NumpyTransducer(torch.autograd.Function):
@staticmethod
def forward(
ctx,
log_probs,
logit_lengths,
target_lengths,
targets,
blank=-1,
):
device = log_probs.device
log_probs = log_probs.cpu().data.numpy()
logit_lengths = logit_lengths.cpu().data.numpy()
target_lengths = target_lengths.cpu().data.numpy()
targets = targets.cpu().data.numpy()
gradients, costs, _, _ = __class__.compute(
log_probs=log_probs,
logit_lengths=logit_lengths,
target_lengths=target_lengths,
targets=targets,
blank=blank,
)
costs = torch.FloatTensor(costs).to(device=device)
gradients = torch.FloatTensor(gradients).to(device=device)
ctx.grads = torch.autograd.Variable(gradients)
return costs
@staticmethod
def backward(ctx, output_gradients):
return ctx.grads, None, None, None, None, None, None, None, None
@staticmethod
def compute_alpha_one_sequence(log_probs, targets, blank=-1):
max_T, max_U, D = log_probs.shape
alpha = np.zeros((max_T, max_U), dtype=np.float32)
for t in range(1, max_T):
alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank]
for u in range(1, max_U):
alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]]
for t in range(1, max_T):
for u in range(1, max_U):
skip = alpha[t - 1, u] + log_probs[t - 1, u, blank]
emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]]
alpha[t, u] = np.logaddexp(skip, emit)
cost = -(alpha[-1, -1] + log_probs[-1, -1, blank])
return alpha, cost
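# For reference, the recursion above follows https://arxiv.org/abs/1211.3711, Section 2.1:
#   alpha(t, u) = logaddexp(alpha(t-1, u) + log_probs(t-1, u, blank),
#                           alpha(t, u-1) + log_probs(t, u-1, targets[u-1]))
# and the per-sequence loss is -(alpha(T-1, U-1) + log_probs(T-1, U-1, blank)),
# i.e. the forward score of emitting all targets and then the final blank.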
@staticmethod
def compute_beta_one_sequence(log_probs, targets, blank=-1):
max_T, max_U, D = log_probs.shape
beta = np.zeros((max_T, max_U), dtype=np.float32)
beta[-1, -1] = log_probs[-1, -1, blank]
for t in reversed(range(max_T - 1)):
beta[t, -1] = beta[t + 1, -1] + log_probs[t, -1, blank]
for u in reversed(range(max_U - 1)):
beta[-1, u] = beta[-1, u + 1] + log_probs[-1, u, targets[u]]
for t in reversed(range(max_T - 1)):
for u in reversed(range(max_U - 1)):
skip = beta[t + 1, u] + log_probs[t, u, blank]
emit = beta[t, u + 1] + log_probs[t, u, targets[u]]
beta[t, u] = np.logaddexp(skip, emit)
cost = -beta[0, 0]
return beta, cost
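# The backward recursion mirrors the forward one:
#   beta(t, u) = logaddexp(beta(t+1, u) + log_probs(t, u, blank),
#                          beta(t, u+1) + log_probs(t, u, targets[u]))
# with beta(T-1, U-1) = log_probs(T-1, U-1, blank), so -beta(0, 0) equals the same
# per-sequence loss computed by compute_alpha_one_sequence.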
@staticmethod
def compute_gradients_one_sequence(
log_probs, alpha, beta, targets, blank=-1
):
max_T, max_U, D = log_probs.shape
gradients = np.full(log_probs.shape, float("-inf"))
cost = -beta[0, 0]
gradients[-1, -1, blank] = alpha[-1, -1]
gradients[:-1, :, blank] = alpha[:-1, :] + beta[1:, :]
for u, l in enumerate(targets):
gradients[:, u, l] = alpha[:, u] + beta[:, u + 1]
gradients = -(np.exp(gradients + log_probs + cost))
return gradients
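# Each finite entry of `gradients` holds alpha(t, u) + beta(t', u'), where (t', u') is
# (t+1, u) for the blank transition (taken as 0 for the final blank at (T-1, U-1)) and
# (t, u+1) for the matching target label; the final value is
# -exp(alpha + beta' + log_prob + cost) with cost = -beta(0, 0), and entries left at
# -inf (impossible transitions) exponentiate to zero.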
@staticmethod
def compute(
log_probs,
logit_lengths,
target_lengths,
targets,
blank=-1,
):
gradients = np.zeros_like(log_probs)
B_tgt, max_T, max_U, D = log_probs.shape
B_src = logit_lengths.shape[0]
H = int(B_tgt / B_src)
alphas = np.zeros((B_tgt, max_T, max_U))
betas = np.zeros((B_tgt, max_T, max_U))
betas.fill(float("-inf"))
alphas.fill(float("-inf"))
costs = np.zeros(B_tgt)
for b_tgt in range(B_tgt):
b_src = int(b_tgt / H)
T = int(logit_lengths[b_src])
# NOTE: see https://arxiv.org/pdf/1211.3711.pdf Section 2.1
U = int(target_lengths[b_tgt]) + 1
seq_log_probs = log_probs[b_tgt, :T, :U, :]
seq_targets = targets[b_tgt, : int(target_lengths[b_tgt])]
alpha, alpha_cost = __class__.compute_alpha_one_sequence(
log_probs=seq_log_probs, targets=seq_targets, blank=blank
)
beta, beta_cost = __class__.compute_beta_one_sequence(
log_probs=seq_log_probs, targets=seq_targets, blank=blank
)
seq_gradients = __class__.compute_gradients_one_sequence(
log_probs=seq_log_probs,
alpha=alpha,
beta=beta,
targets=seq_targets,
blank=blank,
)
np.testing.assert_almost_equal(alpha_cost, beta_cost, decimal=2)
gradients[b_tgt, :T, :U, :] = seq_gradients
costs[b_tgt] = beta_cost
alphas[b_tgt, :T, :U] = alpha
betas[b_tgt, :T, :U] = beta
return gradients, costs, alphas, betas
class NumpyTransducerLoss(torch.nn.Module):
def __init__(self, blank=-1):
super().__init__()
self.blank = blank
def forward(
self,
logits,
logit_lengths,
target_lengths,
targets,
):
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
return _NumpyTransducer.apply(
log_probs,
logit_lengths,
target_lengths,
targets,
self.blank,
)
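For a quick, hypothetical smoke test of the reference implementation above (the import path is an assumption), the module can be driven with the same toy shapes the tests use:

import torch
from numpy_transducer import NumpyTransducerLoss  # assumed import of the module defined above

logits = torch.rand(1, 2, 3, 5, requires_grad=True)  # (B, max_T, max_U, D)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

costs = NumpyTransducerLoss(blank=-1)(logits, logit_lengths, target_lengths, targets)
costs.sum().backward()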
import torch
from torchaudio_unittest import common_utils
from .utils import skipIfNoTransducer
from .rnnt_loss_impl import RNNTLossTest
@skipIfNoTransducer
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
device = torch.device('cpu')
import numpy as np
from torchaudio.prototype.rnnt_loss import RNNTLoss
from .utils import (
compute_with_numpy_transducer,
compute_with_pytorch_transducer,
get_B1_T10_U3_D4_data,
get_data_basic,
get_numpy_data_B1_T2_U3_D5,
get_numpy_data_B2_T4_U3_D3,
get_numpy_random_data,
numpy_to_torch,
)
class RNNTLossTest:
def _test_costs_and_gradients(
self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
):
logits_shape = data["logits"].shape
for reuse_logits_for_grads in [False, True]:
with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads):
costs, gradients = compute_with_pytorch_transducer(
data=data, reuse_logits_for_grads=reuse_logits_for_grads
)
np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
for b in range(len(gradients)):
T = data["logit_lengths"][b]
U = data["target_lengths"][b]
for t in range(gradients.shape[1]):
for u in range(gradients.shape[2]):
np.testing.assert_allclose(
gradients[b, t, u],
ref_gradients[b, t, u],
atol=atol,
rtol=rtol,
err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
)
def test_basic_backward(self):
rnnt_loss = RNNTLoss()
logits, targets, logit_lengths, target_lengths = get_data_basic(self.device)
loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward()
def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5(
dtype=np.float32
)
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
def test_costs_and_gradients_B1_T2_U3_D5_fp16(self):
data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5(
dtype=np.float16
)
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients(
data=data,
ref_costs=ref_costs,
ref_gradients=ref_gradients,
atol=1e-3,
rtol=1e-2,
)
def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(
dtype=np.float32
)
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
def test_costs_and_gradients_B2_T4_U3_D3_fp16(self):
data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(
dtype=np.float16
)
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients(
data=data,
ref_costs=ref_costs,
ref_gradients=ref_gradients,
atol=1e-3,
rtol=1e-2,
)
def test_costs_and_gradients_random_data_with_numpy_fp32(self):
seed = 777
for i in range(5):
data = get_numpy_random_data(dtype=np.float32, seed=(seed + i))
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
def test_rnnt_nonfused_log_softmax(self):
for random in [False, True]:
data = get_B1_T10_U3_D4_data(
random=random,
)
data = numpy_to_torch(
data=data, device=self.device, requires_grad=True
)
data["fused_log_softmax"] = False
ref_costs, ref_gradients = compute_with_numpy_transducer(
data=data
)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
import unittest
import numpy as np
import torch
from torchaudio.prototype.transducer import RNNTLoss
from torchaudio.prototype.rnnt_loss import RNNTLoss
from .numpy_transducer import NumpyTransducerLoss
def compute_with_numpy_transducer(data):
costs = NumpyTransducerLoss(
blank=data["blank"],
)(
logits=data["logits"],
logit_lengths=data["logit_lengths"],
target_lengths=data["target_lengths"],
targets=data["targets"],
)
from torchaudio_unittest.common_utils import TorchaudioTestCase
loss = torch.sum(costs)
loss.backward()
costs = costs.cpu().data.numpy()
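# saved_grad is populated by the grad hook that numpy_to_torch (defined later in this file) registers on logits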
gradients = data["logits"].saved_grad.cpu().data.numpy()
return costs, gradients
def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
costs = RNNTLoss(
blank=data["blank"],
fused_log_softmax=data.get("fused_log_softmax", True),
reuse_logits_for_grads=reuse_logits_for_grads,
)(
logits=data["logits"],
logit_lengths=data["logit_lengths"],
target_lengths=data["target_lengths"],
targets=data["targets"],
)
loss = torch.sum(costs)
loss.backward()
costs = costs.cpu().data.numpy()
gradients = data["logits"].saved_grad.cpu().data.numpy()
return costs, gradients
def get_data_basic(device):
@@ -10,7 +50,7 @@ def get_data_basic(device):
# in 6f73a2513dc784c59eec153a45f40bc528355b18
# of https://github.com/HawkAaron/warp-transducer
acts = torch.tensor(
logits = torch.tensor(
[
[
[
@@ -27,24 +67,135 @@ def get_data_basic(device):
],
dtype=torch.float,
)
labels = torch.tensor([[1, 2]], dtype=torch.int)
act_length = torch.tensor([2], dtype=torch.int)
label_length = torch.tensor([2], dtype=torch.int)
targets = torch.tensor([[1, 2]], dtype=torch.int)
logit_lengths = torch.tensor([2], dtype=torch.int)
target_lengths = torch.tensor([2], dtype=torch.int)
logits = logits.to(device=device)
targets = targets.to(device=device)
logit_lengths = logit_lengths.to(device=device)
target_lengths = target_lengths.to(device=device)
acts = acts.to(device)
labels = labels.to(device)
act_length = act_length.to(device)
label_length = label_length.to(device)
logits.requires_grad_(True)
return logits, targets, logit_lengths, target_lengths
def get_B1_T10_U3_D4_data(
random=False,
dtype=np.float32,
nan=False,
):
B, T, U, D = 2, 10, 3, 4
data = {}
data["logits"] = np.random.rand(B, T, U, D).astype(dtype)
if not random:
data["logits"].fill(0.1)
if nan:
for i in range(B):
data["logits"][i][0][0][0] = np.nan
data["logit_lengths"] = np.array([10, 10], dtype=np.int32)
data["target_lengths"] = np.array([2, 2], dtype=np.int32)
data["targets"] = np.array([[1, 2], [1, 2]], dtype=np.int32)
data["blank"] = 0
return data
acts.requires_grad_(True)
return acts, labels, act_length, label_length
def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
logits = np.array(
[
0.1,
0.6,
0.1,
0.1,
0.1,
0.1,
0.1,
0.6,
0.1,
0.1,
0.1,
0.1,
0.2,
0.8,
0.1,
0.1,
0.6,
0.1,
0.1,
0.1,
0.1,
0.1,
0.2,
0.1,
0.1,
0.7,
0.1,
0.2,
0.1,
0.1,
],
dtype=dtype,
).reshape(1, 2, 3, 5)
targets = np.array([[1, 2]], dtype=np.int32)
logit_lengths = np.array([2], dtype=np.int32)
target_lengths = np.array([2], dtype=np.int32)
blank = -1
def get_data_B2_T4_U3_D3(dtype=torch.float32, device="cpu"):
ref_costs = np.array([5.09566688538], dtype=dtype)
ref_gradients = np.array(
[
0.17703132,
-0.39992708,
0.17703132,
0.17703132,
-0.13116692,
0.12247062,
0.12247062,
-0.181684,
0.12247062,
-0.1857276,
0.06269141,
0.06269141,
0.06928471,
0.12624498,
-0.32091248,
0.05456069,
-0.2182428,
0.05456069,
0.05456069,
0.05456069,
0.12073967,
0.12073967,
-0.48295838,
0.12073967,
0.12073967,
0.30741188,
0.16871123,
0.18645471,
0.16871123,
-0.83128875,
],
dtype=dtype,
).reshape(1, 2, 3, 5)
data = {
"logits": logits,
"targets": targets,
"logit_lengths": logit_lengths,
"target_lengths": target_lengths,
"blank": blank,
}
return data, ref_costs, ref_gradients
def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
# Test from D21322854
logits = torch.tensor(
logits = np.array(
[
0.065357,
0.787530,
@@ -122,15 +273,15 @@ def get_data_B2_T4_U3_D3(dtype=torch.float32, device="cpu"):
dtype=dtype,
).reshape(2, 4, 3, 3)
targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32)
src_lengths = torch.tensor([4, 4], dtype=torch.int32)
tgt_lengths = torch.tensor([2, 2], dtype=torch.int32)
targets = np.array([[1, 2], [1, 1]], dtype=np.int32)
logit_lengths = np.array([4, 4], dtype=np.int32)
target_lengths = np.array([2, 2], dtype=np.int32)
blank = 0
ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype)
ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype)
ref_gradients = torch.tensor(
ref_gradients = np.array(
[
-0.186844,
-0.062555,
@@ -208,78 +359,92 @@ def get_data_B2_T4_U3_D3(dtype=torch.float32, device="cpu"):
dtype=dtype,
).reshape(2, 4, 3, 3)
logits.requires_grad_(True)
logits = logits.to(device)
def grad_hook(grad):
logits.saved_grad = grad.clone()
logits.register_hook(grad_hook)
data = {
"logits": logits,
"targets": targets,
"src_lengths": src_lengths,
"tgt_lengths": tgt_lengths,
"logit_lengths": logit_lengths,
"target_lengths": target_lengths,
"blank": blank,
}
return data, ref_costs, ref_gradients
def compute_with_pytorch_transducer(data):
costs = RNNTLoss(blank=data["blank"], reduction="none")(
acts=data["logits"],
labels=data["targets"],
act_lens=data["src_lengths"],
label_lens=data["tgt_lengths"],
)
def get_numpy_random_data(
max_B=8, max_T=128, max_U=32, max_D=40, blank=-1, dtype=np.float32, seed=None
):
if seed is not None:
np.random.seed(seed=seed)
loss = torch.sum(costs)
loss.backward()
costs = costs.cpu()
gradients = data["logits"].saved_grad.cpu()
return costs, gradients
if blank != -1:
raise ValueError("blank != -1 is not supported yet.")
B = np.random.randint(low=1, high=max_B)
T = np.random.randint(low=5, high=max_T)
U = np.random.randint(low=5, high=max_U)
D = np.random.randint(low=2, high=max_D)
def skipIfNoTransducer(test_item):
try:
torch.ops.torchaudio.rnnt_loss
return test_item
except RuntimeError:
return unittest.skip("torchaudio C++ extension is not compiled with RNN transducer loss")(test_item)
logit_lengths = np.random.randint(low=5, high=T + 1, size=(B,), dtype=np.int32)
target_lengths = np.random.randint(low=5, high=U + 1, size=(B,), dtype=np.int32)
max_src_length = np.max(logit_lengths)
max_tgt_length = np.max(target_lengths)
targets = np.random.randint(
low=0, high=D - 1, size=(B, max_tgt_length), dtype=np.int32
)
logits = np.random.random_sample(
size=(B, max_src_length, max_tgt_length + 1, D)
).astype(dtype=dtype)
return {
"logits": logits,
"targets": targets,
"logit_lengths": logit_lengths,
"target_lengths": target_lengths,
"blank": blank,
}
class TransducerTester:
def test_basic_fp16_error(self):
rnnt_loss = RNNTLoss()
acts, labels, act_length, label_length = get_data_basic(self.device)
acts = acts.to(torch.float16)
# RuntimeError raised by log_softmax before reaching transducer's bindings
self.assertRaises(
RuntimeError, rnnt_loss, acts, labels, act_length, label_length
def numpy_to_torch(data, device, requires_grad=True):
logits = torch.from_numpy(data["logits"])
targets = torch.from_numpy(data["targets"])
logit_lengths = torch.from_numpy(data["logit_lengths"])
target_lengths = torch.from_numpy(data["target_lengths"])
if "nbest_wers" in data:
data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
if "nbest_scores" in data:
data["nbest_scores"] = torch.from_numpy(data["nbest_scores"]).to(
device=device
)
def test_basic_backward(self):
rnnt_loss = RNNTLoss()
acts, labels, act_length, label_length = get_data_basic(self.device)
loss = rnnt_loss(acts, labels, act_length, label_length)
loss.backward()
logits = torch.autograd.Variable(logits, requires_grad=requires_grad)
logit_lengths = torch.autograd.Variable(logit_lengths)
target_lengths = torch.autograd.Variable(target_lengths)
targets = torch.autograd.Variable(targets)
def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
if device == torch.device("cpu"):
logits = logits.cpu()
elif device == torch.device("cuda"):
logits = logits.cuda()
else:
raise ValueError("unrecognized device = {}".format(device))
data, ref_costs, ref_gradients = get_data_B2_T4_U3_D3(
dtype=torch.float32, device=self.device
)
logits_shape = data["logits"].shape
costs, gradients = compute_with_pytorch_transducer(data=data)
def grad_hook(grad):
logits.saved_grad = grad.clone()
logits.register_hook(grad_hook)
atol, rtol = 1e-6, 1e-2
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
data["logits"] = logits
data["logit_lengths"] = logit_lengths
data["target_lengths"] = target_lengths
data["targets"] = targets
return data
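A hypothetical end-to-end check chaining the helpers above, mirroring what test_costs_and_gradients_random_data_with_numpy_fp32 does with them:

data = get_numpy_random_data(dtype=np.float32, seed=0)
data = numpy_to_torch(data=data, device=torch.device("cpu"), requires_grad=True)
ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
costs, gradients = compute_with_pytorch_transducer(data=data)
np.testing.assert_allclose(costs, ref_costs, atol=1e-6, rtol=1e-2)
np.testing.assert_allclose(gradients, ref_gradients, atol=1e-6, rtol=1e-2)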
@skipIfNoTransducer
class CPUTransducerTester(TransducerTester, TorchaudioTestCase):
device = "cpu"
def skipIfNoTransducer(test_item):
try:
torch.ops.torchaudio.rnnt_loss
return test_item
except RuntimeError:
return unittest.skip("torchaudio C++ extension is not compiled with RNN transducer loss")
@@ -21,12 +21,4 @@ if (BUILD_KALDI)
list(APPEND TORCHAUDIO_THIRD_PARTIES kaldi)
endif()
################################################################################
# transducer
################################################################################
if (BUILD_TRANSDUCER)
add_subdirectory(transducer)
list(APPEND TORCHAUDIO_THIRD_PARTIES warprnnt)
endif()
set_property(GLOBAL PROPERTY TORCHAUDIO_THIRD_PARTIES "${TORCHAUDIO_THIRD_PARTIES}")
add_library(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp)
target_compile_definitions(warprnnt PRIVATE RNNT_DISABLE_OMP)
target_include_directories(warprnnt PUBLIC submodule/include)
Subproject commit f546575109111c455354861a0567c8aa794208a2
@@ -11,7 +11,16 @@ set(
)
if(BUILD_TRANSDUCER)
list(APPEND LIBTORCHAUDIO_SOURCES transducer.cpp)
set(
TRANSDUCER_SOURCES
rnnt/cpu/compute_alphas.cpp
rnnt/cpu/compute_betas.cpp
rnnt/cpu/compute.cpp
rnnt/compute_alphas.cpp
rnnt/compute_betas.cpp
rnnt/compute.cpp
)
list(APPEND LIBTORCHAUDIO_SOURCES ${TRANSDUCER_SOURCES})
endif()
if(BUILD_KALDI)
......
#include <torch/script.h>
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss(Tensor logits,"
"Tensor targets,"
"Tensor src_lengths,"
"Tensor tgt_lengths,"
"int blank,"
"float clamp,"
"bool fused_log_smax=True,"
"bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
}
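For illustration only, the schema above is reachable from Python through torch.ops once the extension is built with BUILD_TRANSDUCER; the prototype RNNTLoss module is the intended entry point, so treat this as a sketch of the raw binding:

import torch

B, T, U, D = 2, 4, 3, 3                               # shapes as in the B2_T4_U3_D3 test case
logits = torch.rand(B, T, U, D, requires_grad=True)   # (batch, max_src_len, max_tgt_len + 1, num_targets)
targets = torch.randint(1, D, (B, U - 1), dtype=torch.int32)
src_lengths = torch.full((B,), T, dtype=torch.int32)
tgt_lengths = torch.full((B,), U - 1, dtype=torch.int32)

# blank=0, clamp=0.0 (a clamp > 0 would clip gradients); the remaining arguments keep their defaults.
costs, gradients = torch.ops.torchaudio.rnnt_loss(
    logits, targets, src_lengths, tgt_lengths, 0, 0.0)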
#include <torch/script.h>
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss_alphas(Tensor logits,"
"Tensor targets,"
"Tensor src_lengths,"
"Tensor tgt_lengths,"
"int blank,"
"float clamp) -> Tensor");
}
#include <torch/script.h>
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss_betas(Tensor logits,"
"Tensor targets,"
"Tensor src_lengths,"
"Tensor tgt_lengths,"
"int blank,"
"float clamp) -> Tensor");
}
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
int64_t blank,
double clamp,
bool fused_log_smax = true,
bool reuse_logits_for_grads = true) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_smax;
CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
torch::Tensor costs = torch::empty(
options.batchSize_ * options.nHypos_,
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = c10::nullopt;
if (logits.requires_grad()) {
if (reuse_logits_for_grads) {
gradients = logits;
} else {
gradients = torch::zeros_like(logits);
}
}
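// When reuse_logits_for_grads is true (the default), the gradient tensor above
// aliases the logits storage to save memory; ComputeGradientsOneSequence zeroes
// the padded region whenever it detects this aliasing.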
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
switch (logits.scalar_type()) {
case torch::ScalarType::Float: {
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
break;
}
case torch::ScalarType::Half: {
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<c10::Half>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr
: gradients->data_ptr<c10::Half>());
break;
}
default: {
break;
}
};
return std::make_tuple(costs, gradients);
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss", &compute);
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
torch::Tensor compute_alphas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
torch::Tensor alphas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only float is supported; this is mainly to enable easy
// unit-testing.
ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*alphas=*/alphas.data_ptr<float>());
return alphas;
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss_alphas", &compute_alphas);
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
torch::Tensor compute_betas(
const torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
int64_t blank,
double clamp) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
torch::Tensor costs = torch::empty(
tgt_lengths.size(0),
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor betas = torch::zeros(
{options.batchSize_ * options.nHypos_,
options.maxSrcLen_,
options.maxTgtLen_},
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));
torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));
Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());
// Only float is supported; this is mainly to enable easy
// unit-testing.
ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*betas=*/betas.data_ptr<float>());
return betas;
}
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss_betas", &compute_betas);
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/rnnt/cpu/math.h>
#include <torchaudio/csrc/rnnt/options.h>
#include <torchaudio/csrc/rnnt/types.h>
#include <cstring>
#include <limits>
#include <vector>
namespace torchaudio {
namespace rnnt {
namespace cpu {
template <typename DTYPE>
struct LogProbs {
DTYPE skip_; // blank.
DTYPE emit_; // target.
LogProbs(DTYPE skip, DTYPE emit) : skip_(skip), emit_(emit) {}
DTYPE& skip() {
return skip_;
}
DTYPE& emit() {
return emit_;
}
const DTYPE& skip() const {
return skip_;
}
const DTYPE& emit() const {
return emit_;
}
};
// TensorView: view a block of allocated memory as a tensor.
template <typename DTYPE>
class TensorView {
public:
TensorView(const std::vector<int>& dims, DTYPE* data)
: dims_(dims), data_(data) {
strides_.resize(dims.size());
strides_.back() = 1;
for (int i = dims.size() - 2; i >= 0; --i) {
strides_[i] = strides_[i + 1] * dims[i + 1];
}
}
DTYPE& operator()(const std::vector<int>& indices) {
CHECK_EQ(indices.size(), dims_.size());
int index = indices.back();
for (int i = indices.size() - 2; i >= 0; --i) {
index += indices[i] * strides_[i];
}
return data_[index];
}
void SetZero() {
int size = dims_[0] * strides_[0];
std::memset(data_, 0, sizeof(DTYPE) * size);
}
private:
std::vector<int> dims_;
std::vector<int> strides_;
DTYPE* data_;
};
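// Row-wise log-sum-exp over the last dimension: for each of the N rows of length D,
// computes max_j x_j + log(sum_j exp(x_j - max_j)). These are the per-(t, u)
// denominators used to turn raw logits into log-probabilities without materializing
// a full softmax.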
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(int N, int D, const DTYPE* logits, CAST_DTYPE* outputs) {
for (int i = 0; i < N * D; i += D) {
CAST_DTYPE max = logits[i];
for (int j = 1; j < D; ++j) {
max = std::max(max, CAST_DTYPE(logits[i + j]));
}
CAST_DTYPE sum = 0;
for (int j = 0; j < D; ++j) {
sum = sum + std::exp(CAST_DTYPE(logits[i + j]) - max);
}
outputs[i / D] = max + std::log(sum);
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeLogProbsOneSequence(
const Options& options,
TensorView<const DTYPE>& logits,
const int* targets,
int srcLen,
int tgtLen,
TensorView<const CAST_DTYPE>& denom,
TensorView<LogProbs<CAST_DTYPE>>& logProbs) {
const int& T = srcLen;
const int& U = tgtLen;
const int& blank = options.blank_;
for (int t = 0; t < T; ++t) {
for (int u = 0; u < U; ++u) {
if (u < U - 1) {
logProbs({t, u}).emit() =
CAST_DTYPE(logits({t, u, targets[u]})) - denom({t, u});
}
logProbs({t, u}).skip() =
CAST_DTYPE(logits({t, u, blank})) - denom({t, u});
}
}
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeLogProbs(
const Options& options,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
CAST_DTYPE* logProbs) {
std::vector<TensorView<const DTYPE>> seqLogits;
std::vector<const int*> seqTargets;
std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
for (int b = 0; b < B; ++b) {
seqLogits.push_back(
TensorView<const DTYPE>({maxT, maxU, D}, logits + b * maxT * maxU * D));
seqTargets.push_back(targets + b * (maxU - 1));
seqDenoms.push_back(TensorView<const CAST_DTYPE>(
{maxT, maxU}, denominators + b * maxT * maxU));
seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
{maxT, maxU},
reinterpret_cast<LogProbs<CAST_DTYPE>*>(logProbs) + b * maxT * maxU));
}
//#pragma omp parallel for
for (int b = 0; b < B; ++b) { // use max 2 * B threads.
ComputeLogProbsOneSequence<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/seqLogits[b],
/*targets=*/seqTargets[b],
/*srcLen=*/srcLengths[b],
/*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
/*denom=*/seqDenoms[b],
/*logProbs=*/seqlogProbs[b]);
}
return SUCCESS;
}
template <typename DTYPE>
DTYPE ComputeAlphaOneSequence(
const Options& options,
TensorView<const LogProbs<DTYPE>>& logProbs,
int srcLen,
int tgtLen,
TensorView<DTYPE>& alpha) {
const int& T = srcLen;
const int& U = tgtLen;
alpha({0, 0}) = DTYPE(0);
for (int t = 1; t < T; ++t) { // u == 0.
alpha({t, 0}) = alpha({t - 1, 0}) + logProbs({t - 1, 0}).skip();
}
for (int u = 1; u < U; ++u) { // t == 0.
alpha({0, u}) = alpha({0, u - 1}) + logProbs({0, u - 1}).emit();
}
for (int t = 1; t < T; ++t) {
for (int u = 1; u < U; ++u) {
alpha({t, u}) = math::lse(
alpha({t - 1, u}) + logProbs({t - 1, u}).skip(),
alpha({t, u - 1}) + logProbs({t, u - 1}).emit());
}
}
DTYPE forward_score = alpha({T - 1, U - 1}) + logProbs({T - 1, U - 1}).skip();
return forward_score;
}
template <typename DTYPE>
DTYPE ComputeBetaOneSequence(
const Options& options,
TensorView<const LogProbs<DTYPE>>& logProbs,
int srcLen,
int tgtLen,
TensorView<DTYPE>& beta) {
const int& T = srcLen;
const int& U = tgtLen;
beta({T - 1, U - 1}) = logProbs({T - 1, U - 1}).skip();
for (int t = T - 2; t >= 0; --t) { // u == U - 1.
beta({t, U - 1}) = beta({t + 1, U - 1}) + logProbs({t, U - 1}).skip();
}
for (int u = U - 2; u >= 0; --u) { // t == T - 1.
beta({T - 1, u}) = beta({T - 1, u + 1}) + logProbs({T - 1, u}).emit();
}
for (int t = T - 2; t >= 0; --t) {
for (int u = U - 2; u >= 0; --u) {
beta({t, u}) = math::lse(
beta({t + 1, u}) + logProbs({t, u}).skip(),
beta({t, u + 1}) + logProbs({t, u}).emit());
}
}
DTYPE backward_score = beta({0, 0});
return backward_score;
}
template <typename DTYPE>
DTYPE ComputeAlphaOrBetaOneSequence(
int thread,
const Options& options,
TensorView<const LogProbs<DTYPE>>& logProbs,
int srcLen,
int tgtLen,
TensorView<DTYPE>& alpha,
TensorView<DTYPE>& beta) {
if (thread & 1) {
return ComputeAlphaOneSequence<DTYPE>(
/*options=*/options,
/*logProbs=*/logProbs,
/*srcLen=*/srcLen,
/*tgtLen=*/tgtLen,
/*alpha=*/alpha);
} else {
return ComputeBetaOneSequence<DTYPE>(
/*options=*/options,
/*logProbs=*/logProbs,
/*srcLen=*/srcLen,
/*tgtLen=*/tgtLen,
/*beta=*/beta);
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphasBetas(
const Options& options,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
CAST_DTYPE* alphas,
CAST_DTYPE* betas,
DTYPE* costs) {
std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
std::vector<TensorView<CAST_DTYPE>> seq_alphas;
std::vector<TensorView<CAST_DTYPE>> seq_betas;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int b = 0; b < B; ++b) {
seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
{maxT, maxU},
reinterpret_cast<LogProbs<CAST_DTYPE>*>(
const_cast<CAST_DTYPE*>(logProbs)) +
b * maxT * maxU));
seq_alphas.push_back(
TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
seq_betas.push_back(
TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
}
std::vector<CAST_DTYPE> scores(B << 1);
//#pragma omp parallel for
for (int t = 0; t < (B << 1); ++t) { // use max 2 * B threads.
int i = (t >> 1);
scores[t] = ComputeAlphaOrBetaOneSequence<CAST_DTYPE>(
/*thread=*/t,
/*options=*/options,
/*logProbs=*/seqlogProbs[i],
/*srcLen=*/srcLengths[i],
/*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
/*alpha=*/seq_alphas[i],
/*beta=*/seq_betas[i]);
}
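// Even-numbered "threads" ran the beta recursion (see thread & 1 above), so
// scores[2 * b] holds beta(0, 0), the per-sequence log-likelihood; negate it to get the cost.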
for (int b = 0; b < B; ++b) {
costs[b] = -scores[b << 1];
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradientsOneSequence(
const Options& options,
TensorView<const DTYPE>& logits,
const int* targets,
int srcLen,
int tgtLen,
TensorView<const CAST_DTYPE>& denom,
TensorView<const CAST_DTYPE>& alpha,
TensorView<const CAST_DTYPE>& beta,
TensorView<DTYPE>& gradients) {
// Don't set gradients to zero here, as the gradient buffer might reuse memory from
// the logits.
const int& T = srcLen;
const int& U = tgtLen;
const int& D = options.numTargets_;
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
CAST_DTYPE cost = -beta({0, 0});
// Note: the gradient below differs from numpy_transducer, since log_softmax is
// computed more efficiently within the loss to save memory. The details of the
// implementation / equations below can be found in Sec 3.2 (function merging) of:
// https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf
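// Concretely, with z = logits(t, u, d) and denom(t, u) the log-sum-exp over d:
//   g             = z + alpha(t, u) + cost - denom(t, u)
//   grad(t, u, d) = exp(g + beta(t, u)) - exp(g + beta(next))
// where beta(next) is beta(t+1, u) for blank, beta(t, u+1) for the matching target
// label, 0 for the final blank at (T-1, U-1), and the second term is dropped for
// every other label.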
for (int t = 0; t < T; ++t) {
for (int u = 0; u < U; ++u) {
CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u});
for (int d = 0; d < D; ++d) {
CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c;
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g);
} else if (d == blank && t < T - 1) {
gradients({t, u, d}) =
std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u}));
} else if (u < U - 1 && d == targets[u]) {
gradients({t, u, d}) =
std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1}));
} else {
gradients({t, u, d}) = std::exp(g + beta({t, u}));
}
if (clamp > 0) {
gradients({t, u, d}) =
math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
gradients({t, u, d}) =
math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
}
}
}
}
// Zero out the rest of the gradients; this is necessary when the gradient buffer
// reuses the logits memory. Check the memory location to see whether that is the case.
if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) {
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int t = T; t < maxT; ++t) {
for (int u = 0; u < maxU; ++u) {
for (int d = 0; d < D; ++d) {
gradients({t, u, d}) = 0.;
}
}
}
for (int t = 0; t < T; ++t) {
for (int u = U; u < maxU; ++u) {
for (int d = 0; d < D; ++d) {
gradients({t, u, d}) = 0.;
}
}
}
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradients(
const Options& options,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
const CAST_DTYPE* denominators,
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients) {
std::vector<TensorView<const DTYPE>> seqLogits;
std::vector<const int*> seqTargets;
std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
std::vector<TensorView<const CAST_DTYPE>> seq_alphas;
std::vector<TensorView<const CAST_DTYPE>> seq_betas;
std::vector<TensorView<DTYPE>> seq_gradients;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
const int& D = options.numTargets_;
for (int b = 0; b < B; ++b) {
seqLogits.push_back(
TensorView<const DTYPE>({maxT, maxU, D}, logits + b * maxT * maxU * D));
seqTargets.push_back(targets + b * (maxU - 1));
seqDenoms.push_back(TensorView<const CAST_DTYPE>(
{maxT, maxU}, denominators + b * maxT * maxU));
seq_alphas.push_back(
TensorView<const CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
seq_betas.push_back(
TensorView<const CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
seq_gradients.push_back(
TensorView<DTYPE>({maxT, maxU, D}, gradients + b * maxT * maxU * D));
}
//#pragma omp parallel for
for (int b = 0; b < B; ++b) { // use max 2 * B threads.
ComputeGradientsOneSequence<DTYPE, CAST_DTYPE>(
/*options=*/options,
/*logits=*/seqLogits[b],
/*targets=*/seqTargets[b],
/*srcLen=*/srcLengths[b],
/*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
/*denom=*/seqDenoms[b],
/*alpha=*/seq_alphas[b],
/*beta=*/seq_betas[b],
/*gradients=*/seq_gradients[b]);
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphas(
const Options& options,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
CAST_DTYPE* alphas) {
std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
std::vector<TensorView<CAST_DTYPE>> seq_alphas;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int b = 0; b < B; ++b) {
seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
{maxT, maxU},
reinterpret_cast<LogProbs<CAST_DTYPE>*>(
const_cast<CAST_DTYPE*>(logProbs)) +
b * maxT * maxU));
seq_alphas.push_back(
TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
}
std::vector<CAST_DTYPE> scores(B << 1);
//#pragma omp parallel for
for (int i = 0; i < B; ++i) { // use max 2 * B threads.
ComputeAlphaOneSequence<DTYPE>(
options,
/*logProbs=*/seqlogProbs[i],
/*srcLen=*/srcLengths[i],
/*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
/*alpha=*/seq_alphas[i]);
}
}
template <typename DTYPE, typename CAST_DTYPE>
void ComputeBetas(
const Options& options,
const CAST_DTYPE* logProbs,
const int* srcLengths,
const int* tgtLengths,
CAST_DTYPE* costs,
CAST_DTYPE* betas) {
std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
std::vector<TensorView<CAST_DTYPE>> seq_betas;
const int& B = options.batchSize_;
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int b = 0; b < B; ++b) {
seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
{maxT, maxU},
reinterpret_cast<LogProbs<CAST_DTYPE>*>(
const_cast<CAST_DTYPE*>(logProbs)) +
b * maxT * maxU));
seq_betas.push_back(
TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
}
std::vector<CAST_DTYPE> scores(B << 1);
//#pragma omp parallel for
for (int i = 0; i < B; ++i) { // use max 2 * B threads.
ComputeBetaOneSequence<DTYPE>(
options,
/*logProbs=*/seqlogProbs[i],
/*srcLen=*/srcLengths[i],
/*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
/*betas=*/seq_betas[i]);
}
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio