Unverified Commit 2c28b743 authored by moto's avatar moto Committed by GitHub
Browse files

Adopt PyTorch's test util to torchscript test (#640)

parent 995b75f8
import os
import tempfile
import unittest
from typing import Type, Iterable
from contextlib import contextmanager
from shutil import copytree
import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import pytest
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio._backend._audio_backends
......@@ -87,6 +88,9 @@ class TestBaseMixin:
device = None
_SKIP_IF_NO_CUDA = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
if dtype not in ['float32', 'float64']:
raise NotImplementedError(f'Unexpected dtype: {dtype}')
......@@ -95,11 +99,10 @@ def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
name = f'Test{testbase.__name__}_{device.upper()}_{dtype.capitalize()}'
attrs = {'dtype': getattr(torch, dtype), 'device': torch.device(device)}
testsuite = type(name, (testbase,), attrs)
testsuite = type(name, (testbase, TestCase), attrs)
if device == 'cuda':
testsuite = pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA not available')(testsuite)
testsuite = _SKIP_IF_NO_CUDA(testsuite)
return testsuite
......
......@@ -23,7 +23,7 @@ class Lfilter(common_utils.TestBaseMixin):
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
torch.testing.assert_allclose(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)
self.assertEqual(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)
def test_clamp(self):
input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device)
......
"""Test suites for jit-ability and its numerical compatibility"""
import unittest
import pytest
import torch
import torchaudio
......@@ -10,29 +9,18 @@ import torchaudio.transforms as T
import common_utils
def _assert_functional_consistency(func, tensor, shape_only=False):
ts_func = torch.jit.script(func)
output = func(tensor)
ts_output = ts_func(tensor)
if shape_only:
assert ts_output.shape == output.shape, (ts_output.shape, output.shape)
else:
torch.testing.assert_allclose(ts_output, output)
def _assert_transforms_consistency(transform, tensor):
ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
torch.testing.assert_allclose(ts_output, output)
class Functional(common_utils.TestBaseMixin):
"""Implements test for `functinoal` modul that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)
return _assert_functional_consistency(func, tensor, shape_only=shape_only)
ts_func = torch.jit.script(func)
output = func(tensor)
ts_output = ts_func(tensor)
if shape_only:
ts_output = ts_output.shape
output = output.shape
self.assertEqual(ts_output, output)
def test_spectrogram(self):
def func(tensor):
......@@ -210,7 +198,7 @@ class Functional(common_utils.TestBaseMixin):
def test_lfilter(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -254,7 +242,7 @@ class Functional(common_utils.TestBaseMixin):
def test_lowpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -268,7 +256,7 @@ class Functional(common_utils.TestBaseMixin):
def test_highpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -282,7 +270,7 @@ class Functional(common_utils.TestBaseMixin):
def test_allpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -297,7 +285,7 @@ class Functional(common_utils.TestBaseMixin):
def test_bandpass_with_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -313,7 +301,7 @@ class Functional(common_utils.TestBaseMixin):
def test_bandpass_without_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -329,7 +317,7 @@ class Functional(common_utils.TestBaseMixin):
def test_bandreject(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -344,7 +332,7 @@ class Functional(common_utils.TestBaseMixin):
def test_band_with_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -360,7 +348,7 @@ class Functional(common_utils.TestBaseMixin):
def test_band_without_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -376,7 +364,7 @@ class Functional(common_utils.TestBaseMixin):
def test_treble(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -392,7 +380,7 @@ class Functional(common_utils.TestBaseMixin):
def test_deemph(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -405,7 +393,7 @@ class Functional(common_utils.TestBaseMixin):
def test_riaa(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -418,7 +406,7 @@ class Functional(common_utils.TestBaseMixin):
def test_equalizer(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -434,7 +422,7 @@ class Functional(common_utils.TestBaseMixin):
def test_perf_biquad_filtering(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -515,7 +503,7 @@ class Functional(common_utils.TestBaseMixin):
def test_phaser(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
gain_in = 0.5
......@@ -534,7 +522,11 @@ class Transforms(common_utils.TestBaseMixin):
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
_assert_transforms_consistency(transform, tensor)
ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)
def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment