Unverified Commit af7eb4d6 authored by Caroline Chen, committed by GitHub

Add torchscript support to RNNT Loss (#1507)

parent 079b3f5d
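Usage sketch (illustrative only, not part of the diff): with this change the prototype RNNT loss can be compiled with TorchScript. Shapes and values below mirror the unit tests added in this commit; the import path is the one those tests use.

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

# Script the loss module -- the pattern the new tests exercise end to end.
loss = torch.jit.script(RNNTLoss())

# logits shape, as in the tests: (batch, max_T, max_U, num_classes)
logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

costs = loss(logits, targets, logit_lengths, target_lengths)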
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cpu')
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
@skipIfNoCuda
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cuda')
import torch
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss


class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
    """Implements test for RNNT Loss that are performed for different devices"""
    def _assert_consistency(self, func, tensor, shape_only=False):
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor)

        self.assertEqual(ts_output, output)

    def test_rnnt_loss(self):
        def func(
            logits,
        ):
            targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
            logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            return rnnt_loss(logits, targets, logit_lengths, target_lengths)

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])

        self._assert_consistency(func, logits)

    def test_RNNTLoss(self):
        func = RNNTLoss()

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
        logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
        target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)

        tensor = logits.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor, targets, logit_lengths, target_lengths)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)

        self.assertEqual(ts_output, output)
@@ -405,10 +405,10 @@ def get_numpy_random_data(
 def numpy_to_torch(data, device, requires_grad=True):
-    logits = torch.from_numpy(data["logits"])
-    targets = torch.from_numpy(data["targets"])
-    logit_lengths = torch.from_numpy(data["logit_lengths"])
-    target_lengths = torch.from_numpy(data["target_lengths"])
+    logits = torch.from_numpy(data["logits"]).to(device=device)
+    targets = torch.from_numpy(data["targets"]).to(device=device)
+    logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
+    target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)

     if "nbest_wers" in data:
         data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
......
@@ -19,6 +19,7 @@ if(BUILD_TRANSDUCER)
     rnnt/compute_alphas.cpp
     rnnt/compute_betas.cpp
     rnnt/compute.cpp
+    rnnt/autograd.cpp
   )
   if (USE_CUDA)
......
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& src_lengths,
      const torch::Tensor& tgt_lengths,
      int64_t blank,
      double clamp,
      bool fused_log_smax = true,
      bool reuse_logits_for_grads = true) {
    at::AutoNonVariableTypeMode g;
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
        targets,
        src_lengths,
        tgt_lengths,
        blank,
        clamp,
        fused_log_smax,
        reuse_logits_for_grads);
    auto costs = std::get<0>(result);
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    torch::Tensor undef;
    return {result, undef, undef, undef, undef, undef, undef, undef};
  }
};

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  auto results = RNNTLossFunction::apply(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
  return std::make_tuple(results[0], results[1]);
}

TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", rnnt_loss_autograd);
}

} // namespace rnnt
} // namespace torchaudio
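A hedged sketch (illustrative only, not part of the diff): the TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) registration above is what routes gradients back to the logits when the raw operator is called. It assumes a torchaudio build with BUILD_TRANSDUCER so the operator is registered; argument names follow the call site in the Python wrapper further down.

import torch
import torchaudio  # noqa: F401  # importing torchaudio loads the extension that registers the op

logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# The operator returns (costs, gradients); backward on costs is handled by the
# RNNTLossFunction registered for the Autograd dispatch key above.
costs, _ = torch.ops.torchaudio.rnnt_loss(
    logits=logits,
    targets=targets,
    src_lengths=logit_lengths,
    tgt_lengths=target_lengths,
    blank=4,       # with 5 classes this is what the Python wrapper's blank=-1 resolves to
    clamp=-1.0,
    fused_log_smax=True,
    reuse_logits_for_grads=False,  # keep logits untouched in this sketch
)
costs.sum().backward()
assert logits.grad is not None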
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                       .typed<decltype(rnnt_loss)>();
  return op.call(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
......
#pragma once

#include <torch/script.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax,
    bool reuse_logits_for_grads);
 import torch
+from torch import Tensor

 __all__ = [
     "RNNTLoss",
@@ -19,15 +20,6 @@ def _rnnt_loss_alphas(
     See documentation for RNNTLoss
     """
-    targets = targets.to(device=logits.device)
-    logit_lengths = logit_lengths.to(device=logits.device)
-    target_lengths = target_lengths.to(device=logits.device)
-
-    # make sure all int tensors are of type int32.
-    targets = targets.int()
-    logit_lengths = logit_lengths.int()
-    target_lengths = target_lengths.int()
-
     return torch.ops.torchaudio.rnnt_loss_alphas(
         logits,
         targets,
@@ -51,15 +43,6 @@ def _rnnt_loss_betas(
     See documentation for RNNTLoss
     """
-    targets = targets.to(device=logits.device)
-    logit_lengths = logit_lengths.to(device=logits.device)
-    target_lengths = target_lengths.to(device=logits.device)
-
-    # make sure all int tensors are of type int32.
-    targets = targets.int()
-    logit_lengths = logit_lengths.int()
-    target_lengths = target_lengths.int()
-
     return torch.ops.torchaudio.rnnt_loss_betas(
         logits,
         targets,
@@ -70,77 +53,15 @@ def _rnnt_loss_betas(
     )


-class _RNNT(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank=-1,
-        clamp=-1,
-        fused_log_softmax=True,
-        reuse_logits_for_grads=True,
-    ):
-        """
-        See documentation for RNNTLoss
-        """
-        # move everything to the same device.
-        targets = targets.to(device=logits.device)
-        logit_lengths = logit_lengths.to(device=logits.device)
-        target_lengths = target_lengths.to(device=logits.device)
-
-        # make sure all int tensors are of type int32.
-        targets = targets.int()
-        logit_lengths = logit_lengths.int()
-        target_lengths = target_lengths.int()
-
-        if blank < 0:  # reinterpret blank index if blank < 0.
-            blank = logits.shape[-1] + blank
-
-        costs, gradients = torch.ops.torchaudio.rnnt_loss(
-            logits=logits,
-            targets=targets,
-            src_lengths=logit_lengths,
-            tgt_lengths=target_lengths,
-            blank=blank,
-            clamp=clamp,
-            fused_log_smax=fused_log_softmax,
-            reuse_logits_for_grads=reuse_logits_for_grads,
-        )
-
-        ctx.grads = gradients
-        return costs
-
-    @staticmethod
-    def backward(ctx, output_gradients):
-        output_gradients = output_gradients.view(-1, 1, 1, 1).to(ctx.grads)
-        ctx.grads.mul_(output_gradients).to(ctx.grads)
-
-        return (
-            ctx.grads,  # logits
-            None,  # targets
-            None,  # logit_lengths
-            None,  # target_lengths
-            None,  # blank
-            None,  # clamp
-            None,  # fused_log_softmax
-            None,  # reuse_logits_for_grads
-        )
-
-
 def rnnt_loss(
-    logits,
-    targets,
-    logit_lengths,
-    target_lengths,
-    blank=-1,
-    clamp=-1,
-    fused_log_softmax=True,
-    reuse_logits_for_grads=True,
+    logits: Tensor,
+    targets: Tensor,
+    logit_lengths: Tensor,
+    target_lengths: Tensor,
+    blank: int = -1,
+    clamp: float = -1,
+    fused_log_softmax: bool = True,
+    reuse_logits_for_grads: bool = True,
 ):
     """
     Compute the RNN Transducer Loss.
@@ -166,17 +87,20 @@ def rnnt_loss(
             False  # softmax needs the original logits value
         )

-    cost = _RNNT.apply(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax,
-        reuse_logits_for_grads,
-    )
-    return cost
+    if blank < 0:  # reinterpret blank index if blank < 0.
+        blank = logits.shape[-1] + blank
+
+    costs, gradients = torch.ops.torchaudio.rnnt_loss(
+        logits=logits,
+        targets=targets,
+        src_lengths=logit_lengths,
+        tgt_lengths=target_lengths,
+        blank=blank,
+        clamp=clamp,
+        fused_log_smax=fused_log_softmax,
+        reuse_logits_for_grads=reuse_logits_for_grads,
+    )
+    return costs


 class RNNTLoss(torch.nn.Module):
@@ -196,10 +120,10 @@ class RNNTLoss(torch.nn.Module):
     def __init__(
         self,
-        blank=-1,
-        clamp=-1,
-        fused_log_softmax=True,
-        reuse_logits_for_grads=True,
+        blank: int = -1,
+        clamp: float = -1.,
+        fused_log_softmax: bool = True,
+        reuse_logits_for_grads: bool = True,
     ):
         super().__init__()
         self.blank = blank
......
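Illustrative follow-up (not part of the diff): with the Python-side autograd.Function removed and rnnt_loss reduced to a typed wrapper around the registered operator, a function that calls it can be scripted directly, which is the pattern the new unit tests at the top of this diff verify. A minimal sketch:

import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss


def loss_fn(logits):
    # Fixed targets/lengths so the scripted function takes only logits,
    # mirroring the wrapper used in the new TorchScript test.
    targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
    logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
    target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
    return rnnt_loss(logits, targets, logit_lengths, target_lengths)


scripted = torch.jit.script(loss_fn)
costs = scripted(torch.rand(1, 2, 3, 5))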