Unverified Commit f37d37d6 authored by moto's avatar moto Committed by GitHub
Browse files

Fix GPU test skip logic (#516)

parent 5f5df1d6
...@@ -33,12 +33,9 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -33,12 +33,9 @@ class TestFunctionalFiltering(unittest.TestCase):
def test_lfilter_basic_double(self): def test_lfilter_basic_double(self):
self._test_lfilter_basic(torch.float64, torch.device("cpu")) self._test_lfilter_basic(torch.float64, torch.device("cpu"))
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_lfilter_basic_gpu(self): def test_lfilter_basic_gpu(self):
if torch.cuda.is_available(): self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))
self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))
else:
print("skipping GPU test for lfilter_basic because device not available")
pass
def _test_lfilter(self, waveform, device): def _test_lfilter(self, waveform, device):
""" """
...@@ -87,16 +84,13 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -87,16 +84,13 @@ class TestFunctionalFiltering(unittest.TestCase):
self._test_lfilter(waveform, torch.device("cpu")) self._test_lfilter(waveform, torch.device("cpu"))
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_lfilter_gpu(self): def test_lfilter_gpu(self):
if torch.cuda.is_available(): filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") waveform, _ = torchaudio.load(filepath, normalization=True)
waveform, _ = torchaudio.load(filepath, normalization=True) cuda0 = torch.device("cuda:0")
cuda0 = torch.device("cuda:0") cuda_waveform = waveform.cuda(device=cuda0)
cuda_waveform = waveform.cuda(device=cuda0) self._test_lfilter(cuda_waveform, cuda0)
self._test_lfilter(cuda_waveform, cuda0)
else:
print("skipping GPU test for lfilter because device not available")
pass
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment