Unverified commit d4d09074, authored by Vincent QB and committed by GitHub

Add autograd gradcheck test for RNN transducer loss (#1532)

* autograd test from carolineechen/audio#2

* fix numpy backward: be careful not to modify gradients in place.
parent 15a7f78c
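
For context: torch.autograd.gradcheck numerically verifies a backward implementation by comparing the analytical Jacobian against a finite-difference estimate. A minimal toy sketch (not part of this commit; softplus is an arbitrary example function):

    import torch
    from torch.autograd import gradcheck

    # Double precision keeps the finite-difference estimate tight; the
    # float32 tests in this commit relax eps and atol to 1e-3 instead.
    x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    assert gradcheck(torch.nn.functional.softplus, (x,), eps=1e-6, atol=1e-4)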
# CPU binding of the shared autograd test.
import torch
from .autograd_impl import Autograd
from torchaudio_unittest import common_utils
from .utils import skipIfNoTransducer


@skipIfNoTransducer
class TestAutograd(Autograd, common_utils.PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cpu')

# CUDA binding of the shared autograd test.
import torch
from .autograd_impl import Autograd
from torchaudio_unittest import common_utils
from .utils import skipIfNoTransducer


@skipIfNoTransducer
@common_utils.skipIfNoCuda
class TestAutograd(Autograd, common_utils.PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cuda')

# autograd_impl: shared gradcheck test implementation.
from typing import Callable, Tuple

import torch
from torch import Tensor
from torch.autograd import gradcheck
from parameterized import parameterized

from torchaudio_unittest.common_utils import TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
from .utils import (
    numpy_to_torch,
    get_B1_T10_U3_D4_data,
    get_numpy_data_B2_T4_U3_D3,
    get_numpy_data_B1_T2_U3_D5,
)
from .numpy_transducer import NumpyTransducerLoss


class Autograd(TestBaseMixin):
    @staticmethod
    def get_data(data_func, device):
        data_np = data_func()
        if isinstance(data_np, tuple):
            data_np = data_np[0]
        data = numpy_to_torch(
            data=data_np, device=device, requires_grad=True
        )
        return data

    def assert_grad(
            self,
            loss: Callable[..., Tensor],
            inputs: Tuple[torch.Tensor],
            *,
            enable_all_grad: bool = True,
    ):
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon
        assert gradcheck(loss, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_numpy_data_B2_T4_U3_D3, ),
        (get_numpy_data_B1_T2_U3_D5, ),
    ])
    def test_RNNTLoss_gradcheck(self, data_func):
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),
            data["targets"],
            data["logit_lengths"],
            data["target_lengths"],
        )
        loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False)
        self.assert_grad(loss, inputs, enable_all_grad=False)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_numpy_data_B2_T4_U3_D3, ),
        (get_numpy_data_B1_T2_U3_D5, ),
    ])
    def test_rnnt_loss_gradcheck(self, data_func):
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),  # logits
            data["targets"],                # targets
            data["logit_lengths"],          # logit_lengths
            data["target_lengths"],         # target_lengths
            data["blank"],                  # blank
            -1,                             # clamp
            True,                           # fused_log_softmax
            False,                          # reuse_logits_for_grads
        )
        self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_numpy_data_B2_T4_U3_D3, ),
        (get_numpy_data_B1_T2_U3_D5, ),
    ])
    def test_np_transducer_gradcheck(self, data_func):
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),
            data["logit_lengths"],
            data["target_lengths"],
            data["targets"],
        )
        loss = NumpyTransducerLoss(blank=data["blank"])
        self.assert_grad(loss, inputs, enable_all_grad=False)
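
For orientation, a hedged usage sketch of the prototype loss these tests exercise. The argument order matches the inputs tuple above; the concrete shapes follow the helper name get_B1_T10_U3_D4_data, while the blank index and the int32 dtypes are assumptions rather than values taken from this commit:

    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss

    # Assumed shapes: logits (B=1, T=10, U=3, D=4); targets hold U-1 = 2 labels.
    logits = torch.randn(1, 10, 3, 4, requires_grad=True)
    targets = torch.tensor([[1, 2]], dtype=torch.int32)    # dtype assumed
    logit_lengths = torch.tensor([10], dtype=torch.int32)  # dtype assumed
    target_lengths = torch.tensor([2], dtype=torch.int32)  # dtype assumed

    loss_fn = RNNTLoss(blank=-1)  # blank index assumed to be the last class
    costs = loss_fn(logits, targets, logit_lengths, target_lengths)
    costs.sum().backward()        # sum in case costs are returned per sequence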

numpy_transducer.py (diff):

@@ -33,8 +33,9 @@ class _NumpyTransducer(torch.autograd.Function):
         return costs

     @staticmethod
-    def backward(ctx, output_gradients):
-        return ctx.grads, None, None, None, None, None, None, None, None
+    def backward(ctx, grad_output):
+        grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
+        return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None

     @staticmethod
     def compute_alpha_one_sequence(log_probs, targets, blank=-1):
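
The old backward returned ctx.grads directly, which is only correct when the upstream gradient is all ones, and handing out the saved buffer risks it being mutated in place downstream. A toy sketch (illustrative only, not the torchaudio implementation) of the corrected pattern:

    import torch

    class Scale2(torch.autograd.Function):
        """Toy Function illustrating the pattern fixed in this commit."""
        @staticmethod
        def forward(ctx, x):
            ctx.grads = torch.full_like(x, 2.0)  # precomputed dY/dX
            return 2.0 * x

        @staticmethod
        def backward(ctx, grad_output):
            # Wrong: `return ctx.grads` ignores grad_output and only holds
            # when the upstream gradient happens to be all ones.
            # Right: scale out-of-place, so ctx.grads is never modified.
            return ctx.grads.mul(grad_output)

    x = torch.randn(3, requires_grad=True)
    y = (Scale2.apply(x) * 3.0).sum()  # upstream gradient is 3, not 1
    y.backward()
    assert torch.allclose(x.grad, torch.full_like(x, 6.0))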