Unverified Commit 7aac931e authored by jieruan's avatar jieruan Committed by GitHub
Browse files

Fix transducer test skip logic (#1253)

parent 7ee1c46b
import unittest
import torch
from torchaudio.prototype.transducer import RNNTLoss
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import TorchaudioTestCase
def get_data_basic(device):
......@@ -241,6 +242,14 @@ def compute_with_pytorch_transducer(data):
return costs, gradients
def skipIfNoTransducer(test_item):
try:
torch.ops.torchaudio.rnnt_loss
return test_item
except RuntimeError:
return unittest.skip("torchaudio C++ extension is not compiled with RNN transducer loss")(test_item)
class TransducerTester:
def test_basic_fp16_error(self):
rnnt_loss = RNNTLoss()
......@@ -271,6 +280,6 @@ class TransducerTester:
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
@common_utils.skipIfNoExtension
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
@skipIfNoTransducer
class CPUTransducerTester(TransducerTester, TorchaudioTestCase):
device = "cpu"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment