"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "d60c52df4abbf3cf4e7048f687433012a777e678"
Commit b29a4639 authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

[BC] Standardization of Transforms/Functionals (#152)

parent 2271a7ae
...@@ -2,6 +2,8 @@ import math ...@@ -2,6 +2,8 @@ import math
import torch import torch
import torchaudio import torchaudio
import torchaudio.functional as F
import pytest
import unittest import unittest
import test.common_utils import test.common_utils
...@@ -11,10 +13,6 @@ if IMPORT_LIBROSA: ...@@ -11,10 +13,6 @@ if IMPORT_LIBROSA:
import numpy as np import numpy as np
import librosa import librosa
import pytest
import torchaudio.functional as F
xfail = pytest.mark.xfail
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)] data_sizes = [(2, 20), (3, 15), (4, 10)]
...@@ -197,54 +195,6 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad): ...@@ -197,54 +195,6 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
@pytest.mark.parametrize('fft_length', [512])
@pytest.mark.parametrize('hop_length', [256])
@pytest.mark.parametrize('waveform', [
    (torch.randn(1, 100000)),
    (torch.randn(1, 2, 100000)),
    # A 100-sample signal is shorter than fft_length=512, so the STFT is
    # expected to raise; mark it as an expected failure.
    pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)),
])
@pytest.mark.parametrize('pad_mode', [
    # 'constant',
    'reflect',
])
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
def test_stft(waveform, fft_length, hop_length, pad_mode):
    """
    Test STFT for multi-channel signals.
    Padding: Value in having padding outside of torch.stft?
    """
    # librosa.stft implicitly pads by fft_length // 2; use the same amount
    # here so the expected frame count below matches librosa's output.
    pad = fft_length // 2
    window = torch.hann_window(fft_length)
    complex_spec = F.stft(waveform,
                          fft_length=fft_length,
                          hop_length=hop_length,
                          window=window,
                          pad_mode=pad_mode)
    mag_spec, phase_spec = F.magphase(complex_spec)

    # == Test shape
    # Expected: batch dims + (freq bins, time frames, real/imag pair).
    expected_size = list(waveform.size()[:-1])
    expected_size += [fft_length // 2 + 1, _num_stft_bins(
        waveform.size(-1), fft_length, hop_length, pad), 2]
    assert complex_spec.dim() == waveform.dim() + 2
    assert complex_spec.size() == torch.Size(expected_size)

    # == Test values
    fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)
    # note that librosa *automatically* pad with fft_length // 2.
    # Apply librosa.stft along the last (time) axis of each channel.
    expected_complex_spec = np.apply_along_axis(librosa.stft, -1,
                                                waveform.numpy(), **fft_config)
    expected_mag_spec, _ = librosa.magphase(expected_complex_spec)

    # Convert torch to np.complex
    # torch stores (real, imag) in a trailing size-2 dimension.
    complex_spec = complex_spec.numpy()
    complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1]

    assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5)
    assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5)
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3]) @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('complex_specgrams', [ @pytest.mark.parametrize('complex_specgrams', [
torch.randn(1, 2, 1025, 400, 2), torch.randn(1, 2, 1025, 400, 2),
......
...@@ -30,40 +30,18 @@ class Test_JIT(unittest.TestCase): ...@@ -30,40 +30,18 @@ class Test_JIT(unittest.TestCase):
self.assertTrue(torch.allclose(jit_out, py_out)) self.assertTrue(torch.allclose(jit_out, py_out))
def test_torchscript_scale(self):
    """Scripted F.scale must agree with the eager implementation."""
    @torch.jit.script
    def scripted(tensor, factor):
        # type: (Tensor, int) -> Tensor
        return F.scale(tensor, factor)

    waveform = torch.rand((10, 1))
    scale_factor = 2
    jit_out = scripted(waveform, scale_factor)
    py_out = F.scale(waveform, scale_factor)
    self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_scale(self):
    """Round-trip the Scale transform through a ScriptModule on CUDA input."""
    cuda_waveform = torch.rand((10, 1), device="cuda")
    self._test_script_module(cuda_waveform, transforms.Scale)
def test_torchscript_pad_trim(self): def test_torchscript_pad_trim(self):
@torch.jit.script @torch.jit.script
def jit_method(tensor, ch_dim, max_len, len_dim, fill_value): def jit_method(tensor, max_len, fill_value):
# type: (Tensor, int, int, int, float) -> Tensor # type: (Tensor, int, float) -> Tensor
return F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value) return F.pad_trim(tensor, max_len, fill_value)
tensor = torch.rand((10, 1)) tensor = torch.rand((1, 10))
ch_dim = 1
max_len = 5 max_len = 5
len_dim = 0
fill_value = 3. fill_value = 3.
jit_out = jit_method(tensor, ch_dim, max_len, len_dim, fill_value) jit_out = jit_method(tensor, max_len, fill_value)
py_out = F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value) py_out = F.pad_trim(tensor, max_len, fill_value)
self.assertTrue(torch.allclose(jit_out, py_out)) self.assertTrue(torch.allclose(jit_out, py_out))
...@@ -74,45 +52,6 @@ class Test_JIT(unittest.TestCase): ...@@ -74,45 +52,6 @@ class Test_JIT(unittest.TestCase):
self._test_script_module(tensor, transforms.PadTrim, max_len) self._test_script_module(tensor, transforms.PadTrim, max_len)
def test_torchscript_downmix_mono(self):
    """Scripted F.downmix_mono must agree with the eager implementation."""
    @torch.jit.script
    def scripted(tensor, ch_dim):
        # type: (Tensor, int) -> Tensor
        return F.downmix_mono(tensor, ch_dim)

    waveform = torch.rand((10, 1))
    channel_dim = 1
    jit_out = scripted(waveform, channel_dim)
    py_out = F.downmix_mono(waveform, channel_dim)
    self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_downmix_mono(self):
    """Round-trip DownmixMono through a ScriptModule on CUDA input."""
    cuda_waveform = torch.rand((1, 10), device="cuda")
    self._test_script_module(cuda_waveform, transforms.DownmixMono)
def test_torchscript_LC2CL(self):
    """Scripted F.LC2CL must agree with the eager implementation."""
    @torch.jit.script
    def scripted(tensor):
        # type: (Tensor) -> Tensor
        return F.LC2CL(tensor)

    waveform = torch.rand((10, 1))
    jit_out = scripted(waveform)
    py_out = F.LC2CL(waveform)
    self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_LC2CL(self):
    """Round-trip LC2CL through a ScriptModule on CUDA input."""
    cuda_waveform = torch.rand((10, 1), device="cuda")
    self._test_script_module(cuda_waveform, transforms.LC2CL)
def test_torchscript_spectrogram(self): def test_torchscript_spectrogram(self):
@torch.jit.script @torch.jit.script
def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize): def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
...@@ -167,7 +106,7 @@ class Test_JIT(unittest.TestCase): ...@@ -167,7 +106,7 @@ class Test_JIT(unittest.TestCase):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor # type: (Tensor, float, float, float, Optional[float]) -> Tensor
return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db) return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db)
spec = torch.rand((10, 1)) spec = torch.rand((6, 201))
multiplier = 10. multiplier = 10.
amin = 1e-10 amin = 1e-10
db_multiplier = 0. db_multiplier = 0.
...@@ -180,7 +119,7 @@ class Test_JIT(unittest.TestCase): ...@@ -180,7 +119,7 @@ class Test_JIT(unittest.TestCase):
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_SpectrogramToDB(self): def test_scriptmodule_SpectrogramToDB(self):
spec = torch.rand((10, 1), device="cuda") spec = torch.rand((6, 201), device="cuda")
self._test_script_module(spec, transforms.SpectrogramToDB) self._test_script_module(spec, transforms.SpectrogramToDB)
...@@ -211,32 +150,13 @@ class Test_JIT(unittest.TestCase): ...@@ -211,32 +150,13 @@ class Test_JIT(unittest.TestCase):
self._test_script_module(tensor, transforms.MelSpectrogram) self._test_script_module(tensor, transforms.MelSpectrogram)
def test_torchscript_BLC2CBL(self):
    """Scripted F.BLC2CBL must agree with the eager implementation."""
    @torch.jit.script
    def scripted(tensor):
        # type: (Tensor) -> Tensor
        return F.BLC2CBL(tensor)

    batch = torch.rand((10, 1000, 1))
    jit_out = scripted(batch)
    py_out = F.BLC2CBL(batch)
    self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_BLC2CBL(self):
    """Round-trip BLC2CBL through a ScriptModule on CUDA input."""
    cuda_batch = torch.rand((10, 1000, 1), device="cuda")
    self._test_script_module(cuda_batch, transforms.BLC2CBL)
def test_torchscript_mu_law_encoding(self): def test_torchscript_mu_law_encoding(self):
@torch.jit.script @torch.jit.script
def jit_method(tensor, qc): def jit_method(tensor, qc):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
return F.mu_law_encoding(tensor, qc) return F.mu_law_encoding(tensor, qc)
tensor = torch.rand((10, 1)) tensor = torch.rand((1, 10))
qc = 256 qc = 256
jit_out = jit_method(tensor, qc) jit_out = jit_method(tensor, qc)
...@@ -246,7 +166,7 @@ class Test_JIT(unittest.TestCase): ...@@ -246,7 +166,7 @@ class Test_JIT(unittest.TestCase):
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_MuLawEncoding(self): def test_scriptmodule_MuLawEncoding(self):
tensor = torch.rand((10, 1), device="cuda") tensor = torch.rand((1, 10), device="cuda")
self._test_script_module(tensor, transforms.MuLawEncoding) self._test_script_module(tensor, transforms.MuLawEncoding)
...@@ -256,7 +176,7 @@ class Test_JIT(unittest.TestCase): ...@@ -256,7 +176,7 @@ class Test_JIT(unittest.TestCase):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
return F.mu_law_expanding(tensor, qc) return F.mu_law_expanding(tensor, qc)
tensor = torch.rand((10, 1)) tensor = torch.rand((1, 10))
qc = 256 qc = 256
jit_out = jit_method(tensor, qc) jit_out = jit_method(tensor, qc)
...@@ -266,7 +186,7 @@ class Test_JIT(unittest.TestCase): ...@@ -266,7 +186,7 @@ class Test_JIT(unittest.TestCase):
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_MuLawExpanding(self): def test_scriptmodule_MuLawExpanding(self):
tensor = torch.rand((10, 1), device="cuda") tensor = torch.rand((1, 10), device="cuda")
self._test_script_module(tensor, transforms.MuLawExpanding) self._test_script_module(tensor, transforms.MuLawExpanding)
......
...@@ -19,191 +19,123 @@ if IMPORT_SCIPY: ...@@ -19,191 +19,123 @@ if IMPORT_SCIPY:
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
# create a sinewave signal for testing # create a sinewave signal for testing
sr = 16000 sample_rate = 16000
freq = 440 freq = 440
volume = .3 volume = .3
sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr)) waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
sig.unsqueeze_(1) # (64000, 1) waveform.unsqueeze_(0) # (1, 64000)
sig = (sig * volume * 2**31).long() waveform = (waveform * volume * 2**31).long()
# file for stereo stft test # file for stereo stft test
test_dirpath, test_dir = test.common_utils.create_temp_assets_dir() test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, "assets", test_filepath = os.path.join(test_dirpath, 'assets',
"steam-train-whistle-daniel_simon.mp3") 'steam-train-whistle-daniel_simon.mp3')
def test_scale(self): def scale(self, waveform, factor=float(2**31)):
# scales a waveform by a factor
audio_orig = self.sig.clone() if not waveform.is_floating_point():
result = transforms.Scale()(audio_orig) waveform = waveform.to(torch.get_default_dtype())
self.assertTrue(result.min() >= -1. and result.max() <= 1.) return waveform / factor
maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
result = transforms.Scale(factor=maxminmax)(audio_orig)
self.assertTrue((result.min() == -1. or result.max() == 1.) and
result.min() >= -1. and result.max() <= 1.)
repr_test = transforms.Scale()
self.assertTrue(repr_test.__repr__())
def test_pad_trim(self): def test_pad_trim(self):
audio_orig = self.sig.clone() waveform = self.waveform.clone()
length_orig = audio_orig.size(0) length_orig = waveform.size(1)
length_new = int(length_orig * 1.2) length_new = int(length_orig * 1.2)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig) result = transforms.PadTrim(max_len=length_new)(waveform)
self.assertEqual(result.size(0), length_new)
result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
self.assertEqual(result.size(1), length_new) self.assertEqual(result.size(1), length_new)
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 0.8) length_new = int(length_orig * 0.8)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig) result = transforms.PadTrim(max_len=length_new)(waveform)
self.assertEqual(result.size(1), length_new)
self.assertEqual(result.size(0), length_new)
repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
self.assertTrue(repr_test.__repr__())
def test_downmix_mono(self):
    """DownmixMono should collapse a two-channel signal to one channel."""
    left = self.sig.clone()
    right = self.sig.clone()
    # Rotate the right channel by 10% of its length so it differs from the left.
    shift = int(right.size(0) * 0.1)
    right = torch.cat((right[shift:], right[:shift]))
    stereo = torch.cat((left, right), dim=1)
    self.assertEqual(stereo.size(1), 2)
    downmixed = transforms.DownmixMono(channels_first=False)(stereo)
    self.assertEqual(downmixed.size(1), 1)
    # The transform should expose a usable repr.
    self.assertTrue(repr(transforms.DownmixMono(channels_first=False)))
def test_lc2cl(self):
    """LC2CL should transpose (length, channels) into (channels, length)."""
    signal = self.sig.clone()
    transposed = transforms.LC2CL()(signal)
    # The output size is exactly the reversed input size.
    self.assertEqual(transposed.size()[::-1], signal.size())
    self.assertTrue(repr(transforms.LC2CL()))
def test_compose(self):
    """Compose should apply Scale then PadTrim in sequence."""
    signal = self.sig.clone()
    padded_len = int(signal.size(0) * 1.2)
    # Scale by the peak magnitude so the output is normalized to [-1, 1].
    peak = max(abs(signal.min()), abs(signal.max())).item()
    pipeline = (transforms.Scale(factor=peak),
                transforms.PadTrim(max_len=padded_len, channels_first=False))
    composed = transforms.Compose(pipeline)
    result = composed(signal)
    # After scaling by the peak, the extreme magnitude is exactly 1.
    self.assertEqual(max(abs(result.min()), abs(result.max())), 1.)
    self.assertEqual(result.size(0), padded_len)
    self.assertTrue(repr(composed))
def test_mu_law_companding(self): def test_mu_law_companding(self):
quantization_channels = 256 quantization_channels = 256
sig = self.sig.clone() waveform = self.waveform.clone()
sig = sig / torch.abs(sig).max() waveform /= torch.abs(waveform).max()
self.assertTrue(sig.min() >= -1. and sig.max() <= 1.) self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)
sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
self.assertTrue(sig_mu.min() >= 0. and sig.max() <= quantization_channels)
sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu) waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.) self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)
repr_test = transforms.MuLawEncoding(quantization_channels) waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu)
self.assertTrue(repr_test.__repr__()) self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
repr_test = transforms.MuLawExpanding(quantization_channels)
self.assertTrue(repr_test.__repr__())
def test_mel2(self): def test_mel2(self):
top_db = 80. top_db = 80.
s2db = transforms.SpectrogramToDB("power", top_db) s2db = transforms.SpectrogramToDB('power', top_db)
audio_orig = self.sig.clone() # (16000, 1) waveform = self.waveform.clone() # (1, 16000)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) waveform_scaled = self.scale(waveform) # (1, 16000)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
mel_transform = transforms.MelSpectrogram() mel_transform = transforms.MelSpectrogram()
# check defaults # check defaults
spectrogram_torch = s2db(mel_transform(audio_scaled)) # (1, 319, 40) spectrogram_torch = s2db(mel_transform(waveform_scaled)) # (1, 128, 321)
self.assertTrue(spectrogram_torch.dim() == 3) self.assertTrue(spectrogram_torch.dim() == 3)
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels) self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
# check correctness of filterbank conversion matrix # check correctness of filterbank conversion matrix
self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all()) self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all()) self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
# check options # check options
kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50} kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
mel_transform2 = transforms.MelSpectrogram(**kwargs) mel_transform2 = transforms.MelSpectrogram(**kwargs)
spectrogram2_torch = s2db(mel_transform2(audio_scaled)) # (1, 506, 50) spectrogram2_torch = s2db(mel_transform2(waveform_scaled)) # (1, 50, 513)
self.assertTrue(spectrogram2_torch.dim() == 3) self.assertTrue(spectrogram2_torch.dim() == 3)
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels) self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all()) self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all()) self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
# check on multi-channel audio # check on multi-channel audio
x_stereo, sr_stereo = torchaudio.load(self.test_filepath) x_stereo, sr_stereo = torchaudio.load(self.test_filepath) # (2, 278756), 44100
spectrogram_stereo = s2db(mel_transform(x_stereo)) spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394)
self.assertTrue(spectrogram_stereo.dim() == 3) self.assertTrue(spectrogram_stereo.dim() == 3)
self.assertTrue(spectrogram_stereo.size(0) == 2) self.assertTrue(spectrogram_stereo.size(0) == 2)
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels) self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
# check filterbank matrix creation # check filterbank matrix creation
fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400) fb_matrix_transform = transforms.MelScale(
n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def test_mfcc(self): def test_mfcc(self):
audio_orig = self.sig.clone() audio_orig = self.waveform.clone()
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) audio_scaled = self.scale(audio_orig) # (1, 16000)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
sample_rate = 16000 sample_rate = 16000
n_mfcc = 40 n_mfcc = 40
n_mels = 128 n_mels = 128
mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
norm='ortho') norm='ortho')
# check defaults # check defaults
torch_mfcc = mfcc_transform(audio_scaled) torch_mfcc = mfcc_transform(audio_scaled) # (1, 40, 321)
self.assertTrue(torch_mfcc.dim() == 3) self.assertTrue(torch_mfcc.dim() == 3)
self.assertTrue(torch_mfcc.shape[2] == n_mfcc) self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
self.assertTrue(torch_mfcc.shape[1] == 321) self.assertTrue(torch_mfcc.shape[2] == 321)
# check melkwargs are passed through # check melkwargs are passed through
melkwargs = {'ws': 200} melkwargs = {'win_length': 200}
mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate, mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
norm='ortho', norm='ortho',
melkwargs=melkwargs) melkwargs=melkwargs)
torch_mfcc2 = mfcc_transform2(audio_scaled) torch_mfcc2 = mfcc_transform2(audio_scaled) # (1, 40, 641)
self.assertTrue(torch_mfcc2.shape[1] == 641) self.assertTrue(torch_mfcc2.shape[2] == 641)
# check norms work correctly # check norms work correctly
mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate, mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
norm=None) norm=None)
torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) # (1, 40, 321)
norm_check = torch_mfcc.clone() norm_check = torch_mfcc.clone()
norm_check[:, :, 0] *= math.sqrt(n_mels) * 2 norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2 norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check)) self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
...@@ -212,45 +144,45 @@ class Tester(unittest.TestCase): ...@@ -212,45 +144,45 @@ class Tester(unittest.TestCase):
def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
sound, sample_rate = torchaudio.load(input_path) sound, sample_rate = torchaudio.load(input_path)
sound_librosa = sound.cpu().numpy().squeeze().T # squeeze batch and channel first sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram # test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2) spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
n_fft=n_fft, n_fft=n_fft,
hop_length=hop_length, hop_length=hop_length,
power=2) power=2)
out_torch = spect_transform(sound).squeeze().cpu().t() out_torch = spect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)) self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))
# test mel spectrogram # test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, melspect_transform = torchaudio.transforms.MelSpectrogram(
hop=hop_length, n_mels=n_mels, n_fft=n_fft) sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
htk=True, norm=None) htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel) librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu().t() torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
# test s2db # test s2db
db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.)
db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.) db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
db_librosa = librosa.core.spectrum.power_to_db(out_librosa) db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t() db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa) db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)) self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))
# test MFCC # test MFCC
melkwargs = {'hop': hop_length, 'n_fft': n_fft} melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
norm='ortho', norm='ortho',
melkwargs=melkwargs) melkwargs=melkwargs)
...@@ -271,7 +203,7 @@ class Tester(unittest.TestCase): ...@@ -271,7 +203,7 @@ class Tester(unittest.TestCase):
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc] librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc) librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu().t() torch_mfcc = mfcc_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3)) self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))
...@@ -308,27 +240,27 @@ class Tester(unittest.TestCase): ...@@ -308,27 +240,27 @@ class Tester(unittest.TestCase):
def test_resample_size(self): def test_resample_size(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
sound, sample_rate = torchaudio.load(input_path) waveform, sample_rate = torchaudio.load(input_path)
upsample_rate = sample_rate * 2 upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2 downsample_rate = sample_rate // 2
invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo') invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
self.assertRaises(ValueError, invalid_resample, sound) self.assertRaises(ValueError, invalid_resample, waveform)
upsample_resample = torchaudio.transforms.Resample( upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method='sinc_interpolation') sample_rate, upsample_rate, resampling_method='sinc_interpolation')
up_sampled = upsample_resample(sound) up_sampled = upsample_resample(waveform)
# we expect the upsampled signal to have twice as many samples # we expect the upsampled signal to have twice as many samples
self.assertTrue(up_sampled.size(-1) == sound.size(-1) * 2) self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
downsample_resample = torchaudio.transforms.Resample( downsample_resample = torchaudio.transforms.Resample(
sample_rate, downsample_rate, resampling_method='sinc_interpolation') sample_rate, downsample_rate, resampling_method='sinc_interpolation')
down_sampled = downsample_resample(sound) down_sampled = downsample_resample(waveform)
# we expect the downsampled signal to have half as many samples # we expect the downsampled signal to have half as many samples
self.assertTrue(down_sampled.size(-1) == sound.size(-1) // 2) self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment