import math
import os
import unittest

import torch
import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY

import common_utils

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy

RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)


def _test_script_module(f, tensor, *args, **kwargs):
    py_method = f(*args, **kwargs)
    jit_method = torch.jit.script(py_method)

    py_out = py_method(tensor)
    jit_out = jit_method(tensor)

    assert torch.allclose(jit_out, py_out)

    if RUN_CUDA:
        tensor = tensor.to("cuda")

        py_method = py_method.cuda()
        jit_method = torch.jit.script(py_method)

        py_out = py_method(tensor)
        jit_out = jit_method(tensor)

        assert torch.allclose(jit_out, py_out)


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()

    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_scriptmodule_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.Spectrogram, tensor)

    def test_scriptmodule_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        _test_script_module(transforms.GriffinLim, tensor, length=1000, rand_init=False)

    def test_mu_law_companding(self):
        quantization_channels = 256

        waveform = self.waveform.clone()
        if not waveform.is_floating_point():
            # mu-law companding expects a float waveform in [-1., 1.]
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_scriptmodule_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        _test_script_module(transforms.AmplitudeToDB, spec)

    def test_batch_AmplitudeToDB(self):
        spec = torch.rand((6, 201))

        # Single then transform then batch
        expected = transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_AmplitudeToDB(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
        power_to_db_transform = transforms.AmplitudeToDB('power', 80.)
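        # For a real-valued input x, magnitude dB is 20 * log10(|x|) and power
        # dB is 10 * log10(x ** 2); these coincide, so the two transforms
        # should produce identical outputs.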
        mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
        power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))

        self.assertTrue(torch.allclose(mag_to_db_torch, power_to_db_torch))

    def test_scriptmodule_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        _test_script_module(transforms.MelScale, spec_f)

    def test_melscale_load_save(self):
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_scriptmodule_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MelSpectrogram, tensor)

    def test_melspectrogram_load_save(self):
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # the defaults are n_fft = 400 and n_mels = 128
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()

        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)

        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())

        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())

        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)

        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
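        # MelScale's filterbank maps n_stft frequency bins to n_mels mel bins,
        # hence the (n_stft, n_mels) = (400, 100) shape checked below.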
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_scriptmodule_MFCC(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MFCC, tensor)

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')

        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)

        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000,)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)
            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa,
                                                         sr=sample_rate,
                                                         n_fft=n_fft,
                                                         hop_length=hop_length,
                                                         n_mels=n_mels,
                                                         htk=True,
                                                         norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
            power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3))

            mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
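            # librosa's amplitude_to_db operates on magnitudes, so apply the
            # torchaudio transform to |x| to compare like with like.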
            mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
            mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
            self.assertTrue(
                torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3)
            )

            power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)
            self.assertTrue(
                torch.allclose(power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)
            )

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in
            # https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:
            #
            #     librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                         sr=sample_rate,
            #                                         n_mfcc=n_mfcc,
            #                                         hop_length=hop_length,
            #                                         n_fft=n_fft,
            #                                         htk=True,
            #                                         norm=None,
            #                                         n_mels=n_mels)
            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }
        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }
        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }
        kwargs4 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 3.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }
        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        # NOTE Test passes offline, but fails on CircleCI, see #372.
        # _test_librosa_consistency_helper(**kwargs3)
        _test_librosa_consistency_helper(**kwargs4)

    def test_scriptmodule_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.
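        # Resample's first two constructor arguments are (orig_freq, new_freq);
        # scripting the module should leave the resampled output unchanged.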
        _test_script_module(transforms.Resample, tensor, sample_rate, sample_rate_2)

    def test_batch_Resample(self):
        waveform = torch.randn(2, 2786)

        # Single then transform then batch
        expected = transforms.Resample()(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Resample()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        _test_script_module(transforms.ComplexNorm, tensor)

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        transform = transforms.ComputeDeltas()
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_MelScale(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 128, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_InverseMelScale(self):
        n_fft = 8
        n_mels = 32
        n_stft = 5
        mel_spec = torch.randn(2, n_mels, 32) ** 2

        # Single then transform then batch
        expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))

        # shape = (3, 2, n_stft, 32)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))

        # InverseMelScale runs SGD from randomly initialized values, so the two
        # runs do not yield exactly the same result. For this reason, the
        # tolerance is very relaxed here.
        self.assertTrue(torch.allclose(computed, expected, atol=1.0))

    def test_batch_compute_deltas(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 31, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawEncoding, tensor)

    def test_scriptmodule_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawDecoding, tensor)

    def test_batch_mulaw(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = transforms.MuLawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

        # Single then transform then batch
        waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.MuLawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_spectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_melspectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_mfcc(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MFCC()(waveform.repeat(3, 1, 1))
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_scriptmodule_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        _test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)

    def test_batch_TimeStretch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        kwargs = {
            'n_fft': 2048,
            'hop_length': 512,
            'win_length': 2048,
            'window': torch.hann_window(2048),
            'center': True,
            'pad_mode': 'reflect',
            'normalized': True,
            'onesided': True,
        }
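        # With n_fft=2048 and onesided=True, torch.stft returns 1025 frequency
        # bins with a trailing (real, imag) dimension, which is the complex
        # spectrogram layout TimeStretch(n_freq=1025) expects.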
        rate = 2

        complex_specgrams = torch.stft(waveform, **kwargs)

        # Single then transform then batch
        expected = transforms.TimeStretch(
            fixed_rate=rate, n_freq=1025, hop_length=512)(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = transforms.TimeStretch(
            fixed_rate=rate, n_freq=1025, hop_length=512)(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_batch_Fade(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000

        # Single then transform then batch
        expected = transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_Fade(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000
        _test_script_module(transforms.Fade, waveform, fade_in_len, fade_out_len)

    def test_scriptmodule_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)

    def test_scriptmodule_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)

    def test_scriptmodule_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        _test_script_module(transforms.Vol, waveform, 1.1)

    def test_batch_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))


class TestLibrosaConsistency(unittest.TestCase):
    test_dirpath = None
    test_dir = None

    @classmethod
    def setUpClass(cls):
        cls.test_dirpath, cls.test_dir = common_utils.create_temp_assets_dir()

    def _to_librosa(self, sound):
        return sound.cpu().numpy().squeeze()

    def _get_sample_data(self, *asset_paths, **kwargs):
        file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths)

        sound, sample_rate = torchaudio.load(file_path, **kwargs)
        return sound.mean(dim=0, keepdim=True), sample_rate

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
    def test_MelScale(self):
        """MelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        hop_length = n_fft // 4

        # Prepare spectrogram input. We use torchaudio to compute one.
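        # Note: htk=True / norm=None are passed to librosa below so it uses the
        # same HTK mel formula and unnormalized filterbank as torchaudio's MelScale.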
        sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3')
        spec_ta = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        spec_lr = spec_ta.cpu().numpy().squeeze()

        # Perform MelScale with torchaudio and librosa
        melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
        melspec_lr = librosa.feature.melspectrogram(
            S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
            win_length=n_fft, center=True, window='hann', n_mels=n_mels,
            htk=True, norm=None)

        # Note: Using relaxed rtol instead of atol
        assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3)

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
    def test_InverseMelScale(self):
        """InverseMelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        n_stft = n_fft // 2 + 1
        hop_length = n_fft // 4

        # Prepare mel spectrogram input. We use torchaudio to compute one.
        sound, sample_rate = self._get_sample_data(
            'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
        spec_orig = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
        melspec_lr = melspec_ta.cpu().numpy().squeeze()

        # Perform InverseMelScale with torchaudio and librosa
        spec_ta = transforms.InverseMelScale(
            n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
        spec_lr = librosa.feature.inverse.mel_to_stft(
            melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
        spec_lr = torch.from_numpy(spec_lr[None, ...])

        # Align dimensions: librosa returns a magnitude spectrogram while
        # torchaudio returns a power spectrogram.
        spec_orig = spec_orig.sqrt()
        spec_ta = spec_ta.sqrt()

        threshold = 2.0
        # This threshold was chosen empirically, based on the following observation:
        #
        #   torch.dist(spec_lr, spec_ta, p=float('inf'))
        #   >>> tensor(1.9666)
        #
        # The spectrograms reconstructed by librosa and torchaudio are not very
        # comparable elementwise. This is because they use different approximation
        # algorithms, so the resulting values can differ in magnitude
        # (although most of them are very close).
        # See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm.
        # See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
        # distance over frequencies.
        assert torch.allclose(spec_ta, spec_lr, atol=threshold)

        threshold = 1700.0
        # This threshold was chosen empirically, based on the following observations:
        #
        #   torch.dist(spec_orig, spec_ta, p=1)
        #   >>> tensor(1644.3516)
        #   torch.dist(spec_orig, spec_lr, p=1)
        #   >>> tensor(1420.7103)
        #   torch.dist(spec_lr, spec_ta, p=1)
        #   >>> tensor(943.2759)
        assert torch.dist(spec_orig, spec_ta, p=1) < threshold


if __name__ == '__main__':
    unittest.main()