Unverified Commit b5b6d30f authored by moto's avatar moto Committed by GitHub
Browse files

Separate CPU and GPU tests for Transforms torchscript test (#520)

parent c29598d5
......@@ -5,11 +5,20 @@ import unittest
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms
import torchaudio.transforms as T
import common_utils
def _assert_transforms_consistency(transform, tensor, device):
tensor = tensor.to(device)
transform = transform.to(device)
ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
torch.testing.assert_allclose(ts_output, output)
def _assert_functional_consistency(py_method, *args, shape_only=False, **kwargs):
jit_method = torch.jit.script(py_method)
......@@ -301,85 +310,64 @@ class TestFunctional(unittest.TestCase):
_assert_functional_consistency(F.lfilter, waveform, a, b)
RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)
def _test_script_module(f, tensor, *args, **kwargs):
py_method = f(*args, **kwargs)
jit_method = torch.jit.script(py_method)
py_out = py_method(tensor)
jit_out = jit_method(tensor)
torch.testing.assert_allclose(jit_out, py_out)
if RUN_CUDA:
tensor = tensor.to("cuda")
py_method = py_method.cuda()
jit_method = torch.jit.script(py_method)
class _TransformsTestMixin:
"""Implements test for Transforms that are performed for different devices"""
device = None
py_out = py_method(tensor)
jit_out = jit_method(tensor)
def _assert_consistency(self, transform, tensor):
_assert_transforms_consistency(transform, tensor, self.device)
torch.testing.assert_allclose(jit_out, py_out)
class TestTransforms(unittest.TestCase):
def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
_test_script_module(torchaudio.transforms.Spectrogram, tensor)
self._assert_consistency(T.Spectrogram(), tensor)
def test_GriffinLim(self):
tensor = torch.rand((1, 201, 6))
_test_script_module(torchaudio.transforms.GriffinLim, tensor, length=1000, rand_init=False)
self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)
def test_AmplitudeToDB(self):
spec = torch.rand((6, 201))
_test_script_module(torchaudio.transforms.AmplitudeToDB, spec)
self._assert_consistency(T.AmplitudeToDB(), spec)
def test_MelScale(self):
spec_f = torch.rand((1, 6, 201))
_test_script_module(torchaudio.transforms.MelScale, spec_f)
self._assert_consistency(T.MelScale(), spec_f)
def test_MelSpectrogram(self):
tensor = torch.rand((1, 1000))
_test_script_module(torchaudio.transforms.MelSpectrogram, tensor)
self._assert_consistency(T.MelSpectrogram(), tensor)
def test_MFCC(self):
tensor = torch.rand((1, 1000))
_test_script_module(torchaudio.transforms.MFCC, tensor)
self._assert_consistency(T.MFCC(), tensor)
def test_Resample(self):
tensor = torch.rand((2, 1000))
sample_rate = 100.
sample_rate_2 = 50.
_test_script_module(torchaudio.transforms.Resample, tensor, sample_rate, sample_rate_2)
self._assert_consistency(T.Resample(sample_rate, sample_rate_2), tensor)
def test_ComplexNorm(self):
tensor = torch.rand((1, 2, 201, 2))
_test_script_module(torchaudio.transforms.ComplexNorm, tensor)
self._assert_consistency(T.ComplexNorm(), tensor)
def test_MuLawEncoding(self):
tensor = torch.rand((1, 10))
_test_script_module(torchaudio.transforms.MuLawEncoding, tensor)
self._assert_consistency(T.MuLawEncoding(), tensor)
def test_MuLawDecoding(self):
tensor = torch.rand((1, 10))
_test_script_module(torchaudio.transforms.MuLawDecoding, tensor)
self._assert_consistency(T.MuLawDecoding(), tensor)
def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2))
_test_script_module(
torchaudio.transforms.TimeStretch,
tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)
def test_Fade(self):
test_filepath = os.path.join(
......@@ -387,24 +375,32 @@ class TestTransforms(unittest.TestCase):
waveform, _ = torchaudio.load(test_filepath)
fade_in_len = 3000
fade_out_len = 3000
_test_script_module(torchaudio.transforms.Fade, waveform, fade_in_len, fade_out_len)
self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)
def test_FrequencyMasking(self):
tensor = torch.rand((10, 2, 50, 10, 2))
_test_script_module(
torchaudio.transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)
self._assert_consistency(T.FrequencyMasking(freq_mask_param=60, iid_masks=False), tensor)
def test_TimeMasking(self):
tensor = torch.rand((10, 2, 50, 10, 2))
_test_script_module(
torchaudio.transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)
self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)
def test_Vol(self):
test_filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath)
_test_script_module(torchaudio.transforms.Vol, waveform, 1.1)
self._assert_consistency(T.Vol(1.1), waveform)
class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on CPU"""
device = torch.device('cpu')
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestTransformsCUDA(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on GPU"""
device = torch.device('cuda')
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment