Unverified Commit 413bd18e authored by moto, committed by GitHub

Extract JIT tests from test_transforms to the dedicated test module (#496)
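The moved tests all follow one pattern: instantiate a transform, compile it with torch.jit.script, and assert that the scripted output matches the eager output. Below is a minimal, illustrative sketch of that pattern; the helper name check_torchscript_consistency is hypothetical, while the actual helper added by this PR is _test_script_module, shown in the diff below.

    # Illustrative sketch only; the PR's real helper is _test_script_module.
    import torch
    import torchaudio

    def check_torchscript_consistency(transform, tensor):
        # Script the eager transform and compare outputs on the same input.
        scripted = torch.jit.script(transform)
        eager_out = transform(tensor)
        scripted_out = scripted(tensor)
        assert torch.allclose(scripted_out, eager_out)

    # Example: Spectrogram on a dummy (1, 1000) waveform.
    check_torchscript_consistency(torchaudio.transforms.Spectrogram(), torch.rand((1, 1000)))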

parent eb5b5a02
@@ -5,6 +5,7 @@ import unittest
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms
import common_utils
@@ -149,3 +150,109 @@ class TestFunctional(unittest.TestCase):
        _test_torchscript_functional_shape(F.dither, tensor)
        _test_torchscript_functional_shape(F.dither, tensor, "RPDF")
        _test_torchscript_functional_shape(F.dither, tensor, "GPDF")

RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)


def _test_script_module(f, tensor, *args, **kwargs):
    # Instantiate the transform, script it with torch.jit.script, and check
    # that eager and scripted outputs match on CPU (and on CUDA if available).
    py_method = f(*args, **kwargs)
    jit_method = torch.jit.script(py_method)

    py_out = py_method(tensor)
    jit_out = jit_method(tensor)

    assert torch.allclose(jit_out, py_out)

    if RUN_CUDA:
        tensor = tensor.to("cuda")

        py_method = py_method.cuda()
        jit_method = torch.jit.script(py_method)

        py_out = py_method(tensor)
        jit_out = jit_method(tensor)

        assert torch.allclose(jit_out, py_out)


class TestTransforms(unittest.TestCase):
    def test_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(torchaudio.transforms.Spectrogram, tensor)

    def test_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        _test_script_module(torchaudio.transforms.GriffinLim, tensor, length=1000, rand_init=False)

    def test_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        _test_script_module(torchaudio.transforms.AmplitudeToDB, spec)

    def test_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        _test_script_module(torchaudio.transforms.MelScale, spec_f)

    def test_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(torchaudio.transforms.MelSpectrogram, tensor)

    def test_MFCC(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(torchaudio.transforms.MFCC, tensor)

    def test_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.
        _test_script_module(torchaudio.transforms.Resample, tensor, sample_rate, sample_rate_2)

    def test_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        _test_script_module(torchaudio.transforms.ComplexNorm, tensor)

    def test_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(torchaudio.transforms.MuLawEncoding, tensor)

    def test_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(torchaudio.transforms.MuLawDecoding, tensor)

    def test_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        _test_script_module(
            torchaudio.transforms.TimeStretch,
            tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)

    def test_Fade(self):
        test_filepath = os.path.join(
            common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000
        _test_script_module(torchaudio.transforms.Fade, waveform, fade_in_len, fade_out_len)

    def test_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(
            torchaudio.transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)

    def test_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(
            torchaudio.transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)

    def test_Vol(self):
        test_filepath = os.path.join(
            common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        _test_script_module(torchaudio.transforms.Vol, waveform, 1.1)
@@ -10,33 +10,6 @@ import torchaudio.functional as F
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir

RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)


def _test_script_module(f, tensor, *args, **kwargs):
    py_method = f(*args, **kwargs)
    jit_method = torch.jit.script(py_method)

    py_out = py_method(tensor)
    jit_out = jit_method(tensor)

    assert torch.allclose(jit_out, py_out)

    if RUN_CUDA:
        tensor = tensor.to("cuda")

        py_method = py_method.cuda()
        jit_method = torch.jit.script(py_method)

        py_out = py_method(tensor)
        jit_out = jit_method(tensor)

        assert torch.allclose(jit_out, py_out)


class Tester(unittest.TestCase):
    # create a sinewave signal for testing
@@ -57,14 +30,6 @@ class Tester(unittest.TestCase):
        waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_scriptmodule_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.Spectrogram, tensor)

    def test_scriptmodule_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        _test_script_module(transforms.GriffinLim, tensor, length=1000, rand_init=False)
    def test_mu_law_companding(self):
        quantization_channels = 256
@@ -79,10 +44,6 @@ class Tester(unittest.TestCase):
        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_scriptmodule_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        _test_script_module(transforms.AmplitudeToDB, spec)
    def test_batch_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
@@ -106,10 +67,6 @@ class Tester(unittest.TestCase):
        self.assertTrue(torch.allclose(mag_to_db_torch, power_to_db_torch))

    def test_scriptmodule_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        _test_script_module(transforms.MelScale, spec_f)
    def test_melscale_load_save(self):
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
@@ -124,10 +81,6 @@ class Tester(unittest.TestCase):
        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_scriptmodule_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MelSpectrogram, tensor)
    def test_melspectrogram_load_save(self):
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
@@ -186,10 +139,6 @@ class Tester(unittest.TestCase):
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_scriptmodule_MFCC(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MFCC, tensor)
    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 16000)
@@ -226,13 +175,6 @@ class Tester(unittest.TestCase):
        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    def test_scriptmodule_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.
        _test_script_module(transforms.Resample, tensor, sample_rate, sample_rate_2)
    def test_batch_Resample(self):
        waveform = torch.randn(2, 2786)
@@ -245,10 +187,6 @@ class Tester(unittest.TestCase):
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        _test_script_module(transforms.ComplexNorm, tensor)
    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)
@@ -349,14 +287,6 @@ class Tester(unittest.TestCase):
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawEncoding, tensor)

    def test_scriptmodule_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawDecoding, tensor)
    def test_batch_mulaw(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
@@ -424,13 +354,6 @@ class Tester(unittest.TestCase):
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_scriptmodule_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        _test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)
    def test_batch_TimeStretch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
@@ -475,26 +398,6 @@ class Tester(unittest.TestCase):
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_Fade(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000
        _test_script_module(transforms.Fade, waveform, fade_in_len, fade_out_len)

    def test_scriptmodule_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)

    def test_scriptmodule_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)

    def test_scriptmodule_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        _test_script_module(transforms.Vol, waveform, 1.1)
    def test_batch_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
...