Unverified commit af7eb4d6, authored by Caroline Chen, committed by GitHub

Add torchscript support to RNNT Loss (#1507)

parent 079b3f5d
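The change moves the RNNT loss autograd logic from a Python torch.autograd.Function into the C++ dispatcher, so the Python entry points become plain scriptable functions. A minimal sketch of what this enables, assuming a torchaudio build with the transducer extension (BUILD_TRANSDUCER) enabled; the shapes (batch 1, 2 source frames, target length 2, 5 labels including blank) and target values are illustrative:

import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss

@torch.jit.script
def scripted_loss(logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical targets/lengths; any int32 tensors with a matching
    # batch size work.
    targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
    logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
    target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
    return rnnt_loss(logits, targets, logit_lengths, target_lengths)

# (batch, max_seq_len, max_target_len + 1, num_labels)
logits = torch.randn(1, 2, 3, 5)
print(scripted_loss(logits))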
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cpu')

import torch

from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
@skipIfNoCuda
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cuda')
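The two small files above follow torchaudio's device-parametrized test pattern: the shared logic lives in a mixin (RNNTLossTorchscript, defined in torchscript_consistency_impl below), and each thin subclass only pins the device. A generic sketch of that pattern, with hypothetical names:

import unittest

import torch

class CheckMixin:
    device: torch.device

    def _check_device(self):
        # Shared test body; runs on whatever device the subclass pins.
        x = torch.zeros(1, device=self.device)
        assert x.device.type == self.device.type

class TestOnCPU(CheckMixin, unittest.TestCase):
    device = torch.device('cpu')

    def test_device(self):
        self._check_device()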

import torch

from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss


class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
    """Implements tests for RNNT loss that are performed for different devices"""
    def _assert_consistency(self, func, tensor, shape_only=False):
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor)

        self.assertEqual(ts_output, output)

    def test_rnnt_loss(self):
        def func(
            logits,
        ):
            targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
            logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            return rnnt_loss(logits, targets, logit_lengths, target_lengths)

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])

        self._assert_consistency(func, logits)

    def test_RNNTLoss(self):
        func = RNNTLoss()

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
        logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
        target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)

        tensor = logits.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor, targets, logit_lengths, target_lengths)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)

        self.assertEqual(ts_output, output)
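The consistency check above is a standard TorchScript round-trip: script the callable, serialize it, reload the compiled artifact, and compare eager and scripted outputs on identical inputs. Stripped of the test-suite mixins, the same pattern looks roughly like this (the function and file name are hypothetical):

import os
import tempfile

import torch

def eager_fn(x: torch.Tensor) -> torch.Tensor:
    return x.relu() + 1.0

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'func.zip')
    torch.jit.script(eager_fn).save(path)  # script and serialize
    ts_fn = torch.jit.load(path)           # reload the compiled artifact

    x = torch.randn(3)
    assert torch.equal(eager_fn(x), ts_fn(x))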
@@ -405,10 +405,10 @@ def get_numpy_random_data(

 def numpy_to_torch(data, device, requires_grad=True):
-    logits = torch.from_numpy(data["logits"])
-    targets = torch.from_numpy(data["targets"])
-    logit_lengths = torch.from_numpy(data["logit_lengths"])
-    target_lengths = torch.from_numpy(data["target_lengths"])
+    logits = torch.from_numpy(data["logits"]).to(device=device)
+    targets = torch.from_numpy(data["targets"]).to(device=device)
+    logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
+    target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)
     if "nbest_wers" in data:
         data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
...
@@ -19,6 +19,7 @@ if(BUILD_TRANSDUCER)
     rnnt/compute_alphas.cpp
     rnnt/compute_betas.cpp
     rnnt/compute.cpp
+    rnnt/autograd.cpp
   )
   if (USE_CUDA)
...
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& src_lengths,
      const torch::Tensor& tgt_lengths,
      int64_t blank,
      double clamp,
      bool fused_log_smax = true,
      bool reuse_logits_for_grads = true) {
    at::AutoNonVariableTypeMode g;
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
        targets,
        src_lengths,
        tgt_lengths,
        blank,
        clamp,
        fused_log_smax,
        reuse_logits_for_grads);
    auto costs = std::get<0>(result);
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    torch::Tensor undef;
    return {result, undef, undef, undef, undef, undef, undef, undef};
  }
};

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  auto results = RNNTLossFunction::apply(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
  return std::make_tuple(results[0], results[1]);
}

TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", rnnt_loss_autograd);
}

} // namespace rnnt
} // namespace torchaudio
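Because rnnt_loss_autograd is registered under the Autograd dispatch key, gradient tracking happens inside the dispatcher rather than in a Python autograd.Function, so backpropagation works identically from eager mode and from TorchScript. A hedged sketch of exercising the raw op from Python, assuming the transducer extension is built; shapes, label values, and the blank index are illustrative:

import torch
import torchaudio  # importing torchaudio registers torchaudio::rnnt_loss

logits = torch.randn(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)

# The op returns (costs, gradients); the Autograd wrapper above saves the
# gradients tensor in forward() and scales it by the incoming grad in backward().
costs, _ = torch.ops.torchaudio.rnnt_loss(
    logits=logits,
    targets=targets,
    src_lengths=torch.tensor([2], dtype=torch.int32),
    tgt_lengths=torch.tensor([2], dtype=torch.int32),
    blank=4,
    clamp=-1.0,
    fused_log_smax=True,
    reuse_logits_for_grads=True,
)
costs.sum().backward()       # flows through RNNTLossFunction::backward
print(logits.grad.shape)     # matches logits: (1, 2, 3, 5)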
 #include <torch/script.h>
+#include <torchaudio/csrc/rnnt/compute.h>
+
+std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
+    torch::Tensor& logits,
+    const torch::Tensor& targets,
+    const torch::Tensor& src_lengths,
+    const torch::Tensor& tgt_lengths,
+    int64_t blank,
+    double clamp,
+    bool fused_log_smax = true,
+    bool reuse_logits_for_grads = true) {
+  static auto op = torch::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
+                       .typed<decltype(rnnt_loss)>();
+  return op.call(
+      logits,
+      targets,
+      src_lengths,
+      tgt_lengths,
+      blank,
+      clamp,
+      fused_log_smax,
+      reuse_logits_for_grads);
+}
+
 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
...
#pragma once

#include <torch/script.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax,
    bool reuse_logits_for_grads);
 import torch
+from torch import Tensor

 __all__ = [
     "RNNTLoss",
...
@@ -19,15 +20,6 @@ def _rnnt_loss_alphas(
     See documentation for RNNTLoss
     """
-    targets = targets.to(device=logits.device)
-    logit_lengths = logit_lengths.to(device=logits.device)
-    target_lengths = target_lengths.to(device=logits.device)
-
-    # make sure all int tensors are of type int32.
-    targets = targets.int()
-    logit_lengths = logit_lengths.int()
-    target_lengths = target_lengths.int()
-
     return torch.ops.torchaudio.rnnt_loss_alphas(
         logits,
         targets,
@@ -51,15 +43,6 @@ def _rnnt_loss_betas(
     See documentation for RNNTLoss
     """
-    targets = targets.to(device=logits.device)
-    logit_lengths = logit_lengths.to(device=logits.device)
-    target_lengths = target_lengths.to(device=logits.device)
-
-    # make sure all int tensors are of type int32.
-    targets = targets.int()
-    logit_lengths = logit_lengths.int()
-    target_lengths = target_lengths.int()
-
     return torch.ops.torchaudio.rnnt_loss_betas(
         logits,
         targets,
@@ -70,77 +53,15 @@ def _rnnt_loss_betas(
     )


-class _RNNT(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank=-1,
-        clamp=-1,
-        fused_log_softmax=True,
-        reuse_logits_for_grads=True,
-    ):
-        """
-        See documentation for RNNTLoss
-        """
-        # move everything to the same device.
-        targets = targets.to(device=logits.device)
-        logit_lengths = logit_lengths.to(device=logits.device)
-        target_lengths = target_lengths.to(device=logits.device)
-
-        # make sure all int tensors are of type int32.
-        targets = targets.int()
-        logit_lengths = logit_lengths.int()
-        target_lengths = target_lengths.int()
-
-        if blank < 0:  # reinterpret blank index if blank < 0.
-            blank = logits.shape[-1] + blank
-
-        costs, gradients = torch.ops.torchaudio.rnnt_loss(
-            logits=logits,
-            targets=targets,
-            src_lengths=logit_lengths,
-            tgt_lengths=target_lengths,
-            blank=blank,
-            clamp=clamp,
-            fused_log_smax=fused_log_softmax,
-            reuse_logits_for_grads=reuse_logits_for_grads,
-        )
-        ctx.grads = gradients
-        return costs
-
-    @staticmethod
-    def backward(ctx, output_gradients):
-        output_gradients = output_gradients.view(-1, 1, 1, 1).to(ctx.grads)
-        ctx.grads.mul_(output_gradients).to(ctx.grads)
-        return (
-            ctx.grads,  # logits
-            None,  # targets
-            None,  # logit_lengths
-            None,  # target_lengths
-            None,  # blank
-            None,  # clamp
-            None,  # fused_log_softmax
-            None,  # reuse_logits_for_grads
-        )
-
-
 def rnnt_loss(
-    logits,
-    targets,
-    logit_lengths,
-    target_lengths,
-    blank=-1,
-    clamp=-1,
-    fused_log_softmax=True,
-    reuse_logits_for_grads=True,
+    logits: Tensor,
+    targets: Tensor,
+    logit_lengths: Tensor,
+    target_lengths: Tensor,
+    blank: int = -1,
+    clamp: float = -1,
+    fused_log_softmax: bool = True,
+    reuse_logits_for_grads: bool = True,
 ):
     """
     Compute the RNN Transducer Loss.
@@ -166,17 +87,20 @@ def rnnt_loss(
             False  # softmax needs the original logits value
         )

-    cost = _RNNT.apply(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax,
-        reuse_logits_for_grads,
-    )
-    return cost
+    if blank < 0:  # reinterpret blank index if blank < 0.
+        blank = logits.shape[-1] + blank
+
+    costs, gradients = torch.ops.torchaudio.rnnt_loss(
+        logits=logits,
+        targets=targets,
+        src_lengths=logit_lengths,
+        tgt_lengths=target_lengths,
+        blank=blank,
+        clamp=clamp,
+        fused_log_smax=fused_log_softmax,
+        reuse_logits_for_grads=reuse_logits_for_grads,
+    )
+    return costs


 class RNNTLoss(torch.nn.Module):
@@ -196,10 +120,10 @@ class RNNTLoss(torch.nn.Module):

     def __init__(
         self,
-        blank=-1,
-        clamp=-1,
-        fused_log_softmax=True,
-        reuse_logits_for_grads=True,
+        blank: int = -1,
+        clamp: float = -1.,
+        fused_log_softmax: bool = True,
+        reuse_logits_for_grads: bool = True,
     ):
         super().__init__()
         self.blank = blank
...
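With _RNNT removed, rnnt_loss is a plain annotated function over the registered op, so the public module compiles directly under TorchScript. A minimal end-to-end sketch mirroring the tests above, assuming the prototype namespace of this commit and a transducer-enabled build:

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

loss = torch.jit.script(RNNTLoss())

logits = torch.randn(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

costs = loss(logits, targets, logit_lengths, target_lengths)
costs.sum().backward()  # autograd now runs through the C++ dispatcher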