Unverified Commit 7aac931e authored by jieruan's avatar jieruan Committed by GitHub
Browse files

Fix transducer test skip logic (#1253)

parent 7ee1c46b
import unittest
import torch import torch
from torchaudio.prototype.transducer import RNNTLoss from torchaudio.prototype.transducer import RNNTLoss
from torchaudio_unittest import common_utils from torchaudio_unittest.common_utils import TorchaudioTestCase
def get_data_basic(device): def get_data_basic(device):
...@@ -241,6 +242,14 @@ def compute_with_pytorch_transducer(data): ...@@ -241,6 +242,14 @@ def compute_with_pytorch_transducer(data):
return costs, gradients return costs, gradients
def skipIfNoTransducer(test_item):
try:
torch.ops.torchaudio.rnnt_loss
return test_item
except RuntimeError:
return unittest.skip("torchaudio C++ extension is not compiled with RNN transducer loss")(test_item)
class TransducerTester: class TransducerTester:
def test_basic_fp16_error(self): def test_basic_fp16_error(self):
rnnt_loss = RNNTLoss() rnnt_loss = RNNTLoss()
...@@ -271,6 +280,6 @@ class TransducerTester: ...@@ -271,6 +280,6 @@ class TransducerTester:
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol) self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
@common_utils.skipIfNoExtension @skipIfNoTransducer
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase): class CPUTransducerTester(TransducerTester, TorchaudioTestCase):
device = "cpu" device = "cpu"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment